diff --git a/builtin/logical/transit/backend.go b/builtin/logical/transit/backend.go index 3bb20d16d9..94d07cbd92 100644 --- a/builtin/logical/transit/backend.go +++ b/builtin/logical/transit/backend.go @@ -6,7 +6,7 @@ import ( ) func Factory(conf *logical.BackendConfig) (logical.Backend, error) { - b := Backend() + b := Backend(conf) be, err := b.Backend.Setup(conf) if err != nil { return nil, err @@ -15,7 +15,7 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) { return be, nil } -func Backend() *backend { +func Backend(conf *logical.BackendConfig) *backend { var b backend b.Backend = &framework.Backend{ Paths: []*framework.Path{ @@ -33,14 +33,12 @@ func Backend() *backend { Secrets: []*framework.Secret{}, } - b.policies = policyCache{ - cache: map[string]*lockingPolicy{}, - } + b.lm = newLockManager(conf.System.CachingDisabled()) return &b } type backend struct { *framework.Backend - policies policyCache + lm *lockManager } diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index dc0efaad41..3b31cc2bbf 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -559,16 +559,31 @@ func TestPolicyFuzzing(t *testing.T) { return } - be := Backend() + var be *backend + sysView := logical.TestSystemView() + be = Backend(&logical.BackendConfig{ + System: sysView, + }) + testPolicyFuzzingCommon(t, be) + + sysView.CachingDisabledVal = true + be = Backend(&logical.BackendConfig{ + System: sysView, + }) + testPolicyFuzzingCommon(t, be) +} + +func testPolicyFuzzingCommon(t *testing.T, be *backend) { storage := &logical.LockingInmemStorage{} wg := sync.WaitGroup{} funcs := []string{"encrypt", "decrypt", "rotate", "change_min_version"} + //keys := []string{"test1", "test2", "test3", "test4", "test5"} keys := []string{"test1", "test2", "test3"} // This is the goroutine loop - doFuzzy := func() { + doFuzzy := func(id int) { // Check for panics, otherwise notify we're done defer func() { if err := recover(); err != nil { @@ -587,6 +602,9 @@ func TestPolicyFuzzing(t *testing.T) { } fd := &framework.FieldData{} + var chosenFunc, chosenKey string + + //t.Errorf("Starting %d", id) for { // Stop after 10 seconds if time.Now().Sub(startTime) > 10*time.Second { @@ -594,8 +612,8 @@ func TestPolicyFuzzing(t *testing.T) { } // Pick a function and a key - chosenFunc := funcs[rand.Int()%len(funcs)] - chosenKey := keys[rand.Int()%len(keys)] + chosenFunc = funcs[rand.Int()%len(funcs)] + chosenKey = keys[rand.Int()%len(keys)] fd.Raw = map[string]interface{}{ "name": chosenKey, @@ -605,33 +623,33 @@ func TestPolicyFuzzing(t *testing.T) { // Try to write the key to make sure it exists _, err := be.pathPolicyWrite(req, fd) if err != nil { - t.Errorf("got an error: %v", err) - return + t.Fatalf("got an error: %v", err) } switch chosenFunc { // Encrypt our plaintext and store the result case "encrypt": + //t.Errorf("%s, %s, %d", chosenFunc, chosenKey, id) fd.Raw["plaintext"] = base64.StdEncoding.EncodeToString([]byte(testPlaintext)) fd.Schema = be.pathEncrypt().Fields resp, err := be.pathEncryptWrite(req, fd) if err != nil { - t.Errorf("got an error: %v, resp is %#v", err, *resp) - return + t.Fatalf("got an error: %v, resp is %#v", err, *resp) } latestEncryptedText[chosenKey] = resp.Data["ciphertext"].(string) // Rotate to a new key version case "rotate": + //t.Errorf("%s, %s, %d", chosenFunc, chosenKey, id) fd.Schema = be.pathRotate().Fields resp, err := be.pathRotateWrite(req, fd) if err != nil { - t.Errorf("got an error: %v, resp is %#v", err, *resp) - return + t.Fatalf("got an error: %v, resp is %#v, chosenKey is %s", err, *resp, chosenKey) } // Decrypt the ciphertext and compare the result case "decrypt": + //t.Errorf("%s, %s, %d", chosenFunc, chosenKey, id) ct := latestEncryptedText[chosenKey] if ct == "" { continue @@ -645,13 +663,12 @@ func TestPolicyFuzzing(t *testing.T) { if resp.Data["error"].(string) == ErrTooOld { continue } - t.Errorf("got an error: %v, resp is %#v, ciphertext was %s", err, *resp, latestEncryptedText[chosenKey]) - return + t.Fatalf("got an error: %v, resp is %#v, ciphertext was %s, chosenKey is %s, id is %d", err, *resp, ct, chosenKey, id) } ptb64 := resp.Data["plaintext"].(string) pt, err := base64.StdEncoding.DecodeString(ptb64) if err != nil { - t.Errorf("got an error decoding base64 plaintext: %v", err) + t.Fatalf("got an error decoding base64 plaintext: %v", err) return } if string(pt) != testPlaintext { @@ -660,10 +677,10 @@ func TestPolicyFuzzing(t *testing.T) { // Change the min version, which also tests the archive functionality case "change_min_version": + //t.Errorf("%s, %s, %d", chosenFunc, chosenKey, id) resp, err := be.pathPolicyRead(req, fd) if err != nil { - t.Errorf("got an error reading policy %s: %v", chosenKey, err) - return + t.Fatalf("got an error reading policy %s: %v", chosenKey, err) } latestVersion := resp.Data["latest_version"].(int) @@ -673,8 +690,7 @@ func TestPolicyFuzzing(t *testing.T) { fd.Schema = be.pathConfig().Fields resp, err = be.pathConfigWrite(req, fd) if err != nil { - t.Errorf("got an error setting min decryption version: %v", err) - return + t.Fatalf("got an error setting min decryption version: %v", err) } } } @@ -683,7 +699,7 @@ func TestPolicyFuzzing(t *testing.T) { // Spawn 1000 of these workers for 10 seconds for i := 0; i < 1000; i++ { wg.Add(1) - go doFuzzy() + go doFuzzy(i) } // Wait for them all to finish diff --git a/builtin/logical/transit/lock_manager.go b/builtin/logical/transit/lock_manager.go new file mode 100644 index 0000000000..515515726f --- /dev/null +++ b/builtin/logical/transit/lock_manager.go @@ -0,0 +1,325 @@ +package transit + +import ( + "encoding/json" + "errors" + "fmt" + "sync" + + "github.com/hashicorp/vault/logical" +) + +const ( + shared = false + exclusive = true +) + +var ( + errNeedExclusiveLock = errors.New("an exclusive lock is needed for this operation") +) + +type lockManager struct { + // A lock for each named key + locks map[string]*sync.RWMutex + + // A mutex for the map itself + locksMutex sync.RWMutex + + // If caching is enabled, the map of name to in-memory policy cache + cache map[string]*Policy + + // Used for global locking, and as the cache map mutex + cacheMutex sync.RWMutex +} + +func newLockManager(cacheDisabled bool) *lockManager { + lm := &lockManager{ + locks: map[string]*sync.RWMutex{}, + } + if !cacheDisabled { + lm.cache = map[string]*Policy{} + } + return lm +} + +func (lm *lockManager) CacheActive() bool { + return lm.cache != nil +} + +func (lm *lockManager) policyLock(name string, lockType bool) *sync.RWMutex { + lm.locksMutex.RLock() + lock := lm.locks[name] + if lock != nil { + // We want to give this up before locking the lock, but it's safe -- + // the only time we ever write to a value in this map is the first time + // we access the value, so it won't be changing out from under us + lm.locksMutex.RUnlock() + if lockType == exclusive { + lock.Lock() + } else { + lock.RLock() + } + return lock + } + + lm.locksMutex.RUnlock() + lm.locksMutex.Lock() + + // Don't defer the unlock call because if we get a valid lock below we want + // to release the lock mutex right away to avoid the possibility of + // deadlock by trying to grab the second lock + + // Check to make sure it hasn't been created since + lock = lm.locks[name] + if lock != nil { + lm.locksMutex.Unlock() + if lockType == exclusive { + lock.Lock() + } else { + lock.RLock() + } + return lock + } + + lock = &sync.RWMutex{} + lm.locks[name] = lock + lm.locksMutex.Unlock() + if lockType == exclusive { + lock.Lock() + } else { + lock.RLock() + } + + return lock +} + +func (lm *lockManager) UnlockPolicy(lock *sync.RWMutex, lockType bool) { + if lockType == exclusive { + lock.Unlock() + } else { + lock.RUnlock() + } +} + +// Get the policy with a read lock. If we get an error saying an exclusive lock +// is needed (for instance, for an upgrade/migration), give up the read lock, +// call again with an exclusive lock, then swap back out for a read lock. +func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) { + p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, shared) + if err == nil || + (err != nil && err != errNeedExclusiveLock) { + return p, lock, err + } + + // Try again while asking for an exlusive lock + p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, exclusive) + if err != nil || p == nil || lock == nil { + return p, lock, err + } + + lock.Unlock() + + p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, shared) + return p, lock, err +} + +// Get the policy with an exclusive lock +func (lm *lockManager) GetPolicyExclusive(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) { + p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, exclusive) + return p, lock, err +} + +// Get the policy with a read lock; if it returns that an exclusive lock is +// needed, retry. If successful, call one more time to get a read lock and +// return the value. +func (lm *lockManager) GetPolicyUpsert(storage logical.Storage, name string, derived bool) (*Policy, *sync.RWMutex, bool, error) { + p, lock, _, err := lm.getPolicyCommon(storage, name, true, derived, shared) + if err == nil || + (err != nil && err != errNeedExclusiveLock) { + return p, lock, false, err + } + + // Try again while asking for an exlusive lock + p, lock, upserted, err := lm.getPolicyCommon(storage, name, true, derived, exclusive) + if err != nil || p == nil || lock == nil { + return p, lock, upserted, err + } + + lock.Unlock() + + // Now get a shared lock for the return, but preserve the value of upsert + p, lock, _, err = lm.getPolicyCommon(storage, name, true, derived, shared) + + return p, lock, upserted, err +} + +// When the function returns, a lock will be held on the policy if err == nil. +// It is the caller's responsibility to unlock. +func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, upsert, derived, lockType bool) (*Policy, *sync.RWMutex, bool, error) { + lock := lm.policyLock(name, lockType) + + var p *Policy + var err error + + // Check if it's in our cache. If so, return right away. + if lm.CacheActive() { + lm.cacheMutex.RLock() + p = lm.cache[name] + if p != nil { + lm.cacheMutex.RUnlock() + return p, lock, false, nil + } + lm.cacheMutex.RUnlock() + } + + // Load it from storage + p, err = lm.getStoredPolicy(storage, name) + if err != nil { + lm.UnlockPolicy(lock, lockType) + return nil, nil, false, err + } + + if p == nil { + // This is the only place we upsert a new policy, so if upsert is not + // specified, or the lock type is wrong, unllock before returning + if !upsert { + lm.UnlockPolicy(lock, lockType) + return nil, nil, false, nil + } + + if lockType != exclusive { + lm.UnlockPolicy(lock, lockType) + return nil, nil, false, errNeedExclusiveLock + } + + p = &Policy{ + Name: name, + CipherMode: "aes-gcm", + Derived: derived, + } + if derived { + p.KDFMode = kdfMode + } + + err = p.rotate(storage) + if err != nil { + lm.UnlockPolicy(lock, lockType) + return nil, nil, false, err + } + + if lm.CacheActive() { + // Since we didn't have the policy in the cache, if there was no + // error, write the value in. + lm.cacheMutex.Lock() + defer lm.cacheMutex.Unlock() + // Make sure a policy didn't appear. If so, it will only be set if + // there was no error, so assume it's good and return that + exp := lm.cache[name] + if exp != nil { + return exp, lock, false, nil + } + if err == nil { + lm.cache[name] = p + } + } + + // We don't need to worry about upgrading since it will be a new policy + return p, lock, true, nil + } + + if p.needsUpgrade() { + if lockType == shared { + lm.UnlockPolicy(lock, lockType) + return nil, nil, false, errNeedExclusiveLock + } + + err = p.upgrade(storage) + if err != nil { + lm.UnlockPolicy(lock, lockType) + return nil, nil, false, err + } + } + + if lm.CacheActive() { + // Since we didn't have the policy in the cache, if there was no + // error, write the value in. + lm.cacheMutex.Lock() + defer lm.cacheMutex.Unlock() + // Make sure a policy didn't appear. If so, it will only be set if + // there was no error, so assume it's good and return that + exp := lm.cache[name] + if exp != nil { + return exp, lock, false, nil + } + if err == nil { + lm.cache[name] = p + } + } + + return p, lock, false, nil +} + +func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error { + lm.cacheMutex.Lock() + lock := lm.policyLock(name, exclusive) + defer lock.Unlock() + defer lm.cacheMutex.Unlock() + + var p *Policy + var err error + + if lm.CacheActive() { + p = lm.cache[name] + } + if p == nil { + p, err = lm.getStoredPolicy(storage, name) + if err != nil { + return err + } + if p == nil { + return fmt.Errorf("could not delete policy; not found") + } + } + + if !p.DeletionAllowed { + return fmt.Errorf("deletion is not allowed for this policy") + } + + err = storage.Delete("policy/" + name) + if err != nil { + return fmt.Errorf("error deleting policy %s: %s", name, err) + } + + err = storage.Delete("archive/" + name) + if err != nil { + return fmt.Errorf("error deleting archive %s: %s", name, err) + } + + if lm.CacheActive() { + delete(lm.cache, name) + } + + return nil +} + +func (lm *lockManager) getStoredPolicy(storage logical.Storage, name string) (*Policy, error) { + // Check if the policy already exists + raw, err := storage.Get("policy/" + name) + if err != nil { + return nil, err + } + if raw == nil { + return nil, nil + } + + // Decode the policy + policy := &Policy{ + Keys: KeyEntryMap{}, + } + err = json.Unmarshal(raw.Value, policy) + if err != nil { + return nil, err + } + + return policy, nil +} diff --git a/builtin/logical/transit/path_config.go b/builtin/logical/transit/path_config.go index 967b743412..5dc84c3f76 100644 --- a/builtin/logical/transit/path_config.go +++ b/builtin/logical/transit/path_config.go @@ -41,25 +41,20 @@ func (b *backend) pathConfigWrite( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - // Check if the policy already exists - lp, err := b.policies.getPolicy(req, name) + // Check if the policy already exists before we lock everything + p, lock, err := b.lm.GetPolicyExclusive(req.Storage, name) + if lock != nil { + defer lock.Unlock() + } if err != nil { return nil, err } - if lp == nil { + if p == nil { return logical.ErrorResponse( - fmt.Sprintf("no existing role named %s could be found", name)), + fmt.Sprintf("no existing key named %s could be found", name)), logical.ErrInvalidRequest } - lp.Lock() - defer lp.Unlock() - - // Verify if wasn't deleted before we grabbed the lock - if lp.policy == nil { - return nil, fmt.Errorf("no existing role named %s could be found", name) - } - resp := &logical.Response{} persistNeeded := false @@ -78,12 +73,12 @@ func (b *backend) pathConfigWrite( } if minDecryptionVersion > 0 && - minDecryptionVersion != lp.policy.MinDecryptionVersion { - if minDecryptionVersion > lp.policy.LatestVersion { + minDecryptionVersion != p.MinDecryptionVersion { + if minDecryptionVersion > p.LatestVersion { return logical.ErrorResponse( - fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, lp.policy.LatestVersion)), nil + fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, p.LatestVersion)), nil } - lp.policy.MinDecryptionVersion = minDecryptionVersion + p.MinDecryptionVersion = minDecryptionVersion persistNeeded = true } } @@ -91,8 +86,8 @@ func (b *backend) pathConfigWrite( allowDeletionInt, ok := d.GetOk("deletion_allowed") if ok { allowDeletion := allowDeletionInt.(bool) - if allowDeletion != lp.policy.DeletionAllowed { - lp.policy.DeletionAllowed = allowDeletion + if allowDeletion != p.DeletionAllowed { + p.DeletionAllowed = allowDeletion persistNeeded = true } } @@ -100,8 +95,8 @@ func (b *backend) pathConfigWrite( // Add this as a guard here before persisting since we now require the min // decryption version to start at 1; even if it's not explicitly set here, // force the upgrade - if lp.policy.MinDecryptionVersion == 0 { - lp.policy.MinDecryptionVersion = 1 + if p.MinDecryptionVersion == 0 { + p.MinDecryptionVersion = 1 persistNeeded = true } @@ -109,7 +104,7 @@ func (b *backend) pathConfigWrite( return nil, nil } - return resp, lp.policy.Persist(req.Storage) + return resp, p.Persist(req.Storage) } const pathConfigHelpSyn = `Configure a named encryption key` diff --git a/builtin/logical/transit/path_datakey.go b/builtin/logical/transit/path_datakey.go index f2bec71d09..817529e9dd 100644 --- a/builtin/logical/transit/path_datakey.go +++ b/builtin/logical/transit/path_datakey.go @@ -73,24 +73,17 @@ func (b *backend) pathDatakeyWrite( } // Get the policy - lp, err := b.policies.getPolicy(req, name) + p, lock, err := b.lm.GetPolicyShared(req.Storage, name) + if lock != nil { + defer lock.RUnlock() + } if err != nil { return nil, err } - - // Error if invalid policy - if lp == nil { + if p == nil { return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } - lp.RLock() - defer lp.RUnlock() - - // Verify if wasn't deleted before we grabbed the lock - if lp.policy == nil { - return nil, fmt.Errorf("no existing policy named %s could be found", name) - } - newKey := make([]byte, 32) bits := d.Get("bits").(int) switch bits { @@ -107,7 +100,7 @@ func (b *backend) pathDatakeyWrite( return nil, err } - ciphertext, err := lp.policy.Encrypt(context, base64.StdEncoding.EncodeToString(newKey)) + ciphertext, err := p.Encrypt(context, base64.StdEncoding.EncodeToString(newKey)) if err != nil { switch err.(type) { case certutil.UserError: diff --git a/builtin/logical/transit/path_decrypt.go b/builtin/logical/transit/path_decrypt.go index 66191b6295..254b2b5093 100644 --- a/builtin/logical/transit/path_decrypt.go +++ b/builtin/logical/transit/path_decrypt.go @@ -58,25 +58,18 @@ func (b *backend) pathDecryptWrite( } // Get the policy - lp, err := b.policies.getPolicy(req, name) + p, lock, err := b.lm.GetPolicyShared(req.Storage, name) + if lock != nil { + defer lock.RUnlock() + } if err != nil { return nil, err } - - // Error if invalid policy - if lp == nil { + if p == nil { return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } - lp.RLock() - defer lp.RUnlock() - - // Verify if wasn't deleted before we grabbed the lock - if lp.policy == nil { - return nil, fmt.Errorf("no existing policy named %s could be found", name) - } - - plaintext, err := lp.policy.Decrypt(context, ciphertext) + plaintext, err := p.Decrypt(context, ciphertext) if err != nil { switch err.(type) { case certutil.UserError: diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go index fc2e2048f3..3494faac81 100644 --- a/builtin/logical/transit/path_encrypt.go +++ b/builtin/logical/transit/path_encrypt.go @@ -3,6 +3,7 @@ package transit import ( "encoding/base64" "fmt" + "sync" "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/logical" @@ -44,12 +45,14 @@ func (b *backend) pathEncrypt() *framework.Path { func (b *backend) pathEncryptExistenceCheck( req *logical.Request, d *framework.FieldData) (bool, error) { name := d.Get("name").(string) - lp, err := b.policies.getPolicy(req, name) + p, lock, err := b.lm.GetPolicyShared(req.Storage, name) + if lock != nil { + defer lock.RUnlock() + } if err != nil { return false, err } - - return lp != nil, nil + return p != nil, nil } func (b *backend) pathEncryptWrite( @@ -63,8 +66,8 @@ func (b *backend) pathEncryptWrite( // Decode the context if any contextRaw := d.Get("context").(string) var context []byte + var err error if len(contextRaw) != 0 { - var err error context, err = base64.StdEncoding.DecodeString(contextRaw) if err != nil { return logical.ErrorResponse("failed to decode context as base64"), logical.ErrInvalidRequest @@ -72,37 +75,25 @@ func (b *backend) pathEncryptWrite( } // Get the policy - lp, err := b.policies.getPolicy(req, name) + var p *Policy + var lock *sync.RWMutex + var upserted bool + if req.Operation == logical.CreateOperation { + p, lock, upserted, err = b.lm.GetPolicyUpsert(req.Storage, name, len(context) != 0) + } else { + p, lock, err = b.lm.GetPolicyShared(req.Storage, name) + } + if lock != nil { + defer lock.RUnlock() + } if err != nil { return nil, err } - - // Error if invalid policy - if lp == nil { - if req.Operation != logical.CreateOperation { - return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest - } - - isDerived := len(context) != 0 - - lp, err = b.policies.generatePolicy(req.Storage, name, isDerived) - // If the error is that the policy has been created in the interim we - // will get the policy back, so only consider it an error if err is not - // nil and we do not get a policy back - if err != nil && lp != nil { - return nil, err - } + if p == nil { + return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } - lp.RLock() - defer lp.RUnlock() - - // Verify if wasn't deleted before we grabbed the lock - if lp.policy == nil { - return nil, fmt.Errorf("no existing policy named %s could be found", name) - } - - ciphertext, err := lp.policy.Encrypt(context, value) + ciphertext, err := p.Encrypt(context, value) if err != nil { switch err.(type) { case certutil.UserError: @@ -124,6 +115,11 @@ func (b *backend) pathEncryptWrite( "ciphertext": ciphertext, }, } + + if req.Operation == logical.CreateOperation && !upserted { + resp.AddWarning("Attempted creation of the key during the encrypt operation, but it was created beforehand") + } + return resp, nil } diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index 32587b6c60..466ca9b508 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -39,57 +39,57 @@ func (b *backend) pathPolicyWrite( name := d.Get("name").(string) derived := d.Get("derived").(bool) - // Check if the policy already exists - existing, err := b.policies.getPolicy(req, name) + p, lock, upserted, err := b.lm.GetPolicyUpsert(req.Storage, name, derived) + if lock != nil { + defer lock.RUnlock() + } if err != nil { return nil, err } - if existing != nil { - return nil, nil + if p == nil { + return nil, fmt.Errorf("error generating key: returned policy was nil") } - // Generate the policy - _, err = b.policies.generatePolicy(req.Storage, name, derived) - return nil, err + resp := &logical.Response{} + if !upserted { + resp.AddWarning(fmt.Sprintf("key %s already existed", name)) + } + + return nil, nil } func (b *backend) pathPolicyRead( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - lp, err := b.policies.getPolicy(req, name) + p, lock, err := b.lm.GetPolicyShared(req.Storage, name) + if lock != nil { + defer lock.RUnlock() + } if err != nil { return nil, err } - if lp == nil { + if p == nil { return nil, nil } - lp.RLock() - defer lp.RUnlock() - - // Verify if wasn't deleted before we grabbed the lock - if lp.policy == nil { - return nil, fmt.Errorf("no existing policy named %s could be found", name) - } - // Return the response resp := &logical.Response{ Data: map[string]interface{}{ - "name": lp.policy.Name, - "cipher_mode": lp.policy.CipherMode, - "derived": lp.policy.Derived, - "deletion_allowed": lp.policy.DeletionAllowed, - "min_decryption_version": lp.policy.MinDecryptionVersion, - "latest_version": lp.policy.LatestVersion, + "name": p.Name, + "cipher_mode": p.CipherMode, + "derived": p.Derived, + "deletion_allowed": p.DeletionAllowed, + "min_decryption_version": p.MinDecryptionVersion, + "latest_version": p.LatestVersion, }, } - if lp.policy.Derived { - resp.Data["kdf_mode"] = lp.policy.KDFMode + if p.Derived { + resp.Data["kdf_mode"] = p.KDFMode } retKeys := map[string]int64{} - for k, v := range lp.policy.Keys { + for k, v := range p.Keys { retKeys[strconv.Itoa(k)] = v.CreationTime } resp.Data["keys"] = retKeys @@ -101,15 +101,8 @@ func (b *backend) pathPolicyDelete( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - lp, err := b.policies.getPolicy(req, name) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf("error looking up policy %s, error is %s", name, err)), err - } - if lp == nil { - return logical.ErrorResponse(fmt.Sprintf("no such key %s", name)), logical.ErrInvalidRequest - } - - err = b.policies.deletePolicy(req.Storage, name) + // Delete does its own locking + err := b.lm.DeletePolicy(req.Storage, name) if err != nil { return logical.ErrorResponse(fmt.Sprintf("error deleting policy %s: %s", name, err)), err } diff --git a/builtin/logical/transit/path_rewrap.go b/builtin/logical/transit/path_rewrap.go index d8d01bddca..a5854feeea 100644 --- a/builtin/logical/transit/path_rewrap.go +++ b/builtin/logical/transit/path_rewrap.go @@ -59,25 +59,19 @@ func (b *backend) pathRewrapWrite( } // Get the policy - lp, err := b.policies.getPolicy(req, name) + p, lock, err := b.lm.GetPolicyShared(req.Storage, name) + if lock != nil { + defer lock.RUnlock() + } if err != nil { return nil, err } - // Error if invalid policy - if lp == nil { + if p == nil { return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } - lp.RLock() - defer lp.RUnlock() - - // Verify if wasn't deleted before we grabbed the lock - if lp.policy == nil { - return nil, fmt.Errorf("no existing policy named %s could be found", name) - } - - plaintext, err := lp.policy.Decrypt(context, value) + plaintext, err := p.Decrypt(context, value) if err != nil { switch err.(type) { case certutil.UserError: @@ -93,7 +87,7 @@ func (b *backend) pathRewrapWrite( return nil, fmt.Errorf("empty plaintext returned during rewrap") } - ciphertext, err := lp.policy.Encrypt(context, plaintext) + ciphertext, err := p.Encrypt(context, plaintext) if err != nil { switch err.(type) { case certutil.UserError: diff --git a/builtin/logical/transit/path_rotate.go b/builtin/logical/transit/path_rotate.go index 90bbc2e185..f10b78d56a 100644 --- a/builtin/logical/transit/path_rotate.go +++ b/builtin/logical/transit/path_rotate.go @@ -1,8 +1,6 @@ package transit import ( - "fmt" - "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -31,26 +29,19 @@ func (b *backend) pathRotateWrite( name := d.Get("name").(string) // Get the policy - lp, err := b.policies.getPolicy(req, name) + p, lock, err := b.lm.GetPolicyExclusive(req.Storage, name) + if lock != nil { + defer lock.Unlock() + } if err != nil { return nil, err } - - // Error if invalid policy - if lp == nil { - return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest + if p == nil { + return logical.ErrorResponse("key not found"), logical.ErrInvalidRequest } - lp.Lock() - defer lp.Unlock() - - // Verify if wasn't deleted before we grabbed the lock - if lp.policy == nil { - return nil, fmt.Errorf("no existing policy named %s could be found", name) - } - - // Generate the policy - err = lp.policy.rotate(req.Storage) + // Rotate the policy + err = p.rotate(req.Storage) return nil, err } diff --git a/builtin/logical/transit/policy.go b/builtin/logical/transit/policy.go index 17d7cc00b3..c1b09e2d9a 100644 --- a/builtin/logical/transit/policy.go +++ b/builtin/logical/transit/policy.go @@ -9,7 +9,6 @@ import ( "fmt" "strconv" "strings" - "sync" "time" "github.com/hashicorp/vault/helper/certutil" @@ -24,197 +23,6 @@ const ( ErrTooOld = "ciphertext version is disallowed by policy (too old)" ) -// policyCache implements a simple locking cache of policies -type policyCache struct { - sync.RWMutex - cache map[string]*lockingPolicy -} - -// getPolicy loads a policy into the cache or returns one already in the cache -func (p *policyCache) getPolicy(req *logical.Request, name string) (*lockingPolicy, error) { - // We don't defer this since we may need to give it up and get a write lock - p.RLock() - - // First, see if we're in the cache -- if so, return that - if p.cache[name] != nil { - defer p.RUnlock() - return p.cache[name], nil - } - - // If we didn't find anything, we'll need to write into the cache, plus possibly - // persist the entry, so lock the cache - p.RUnlock() - p.Lock() - defer p.Unlock() - - // Check one more time to ensure that another process did not write during - // our lock switcheroo. - if p.cache[name] != nil { - return p.cache[name], nil - } - - // Note that we don't need to create the locking entry until the end, - // because the policy wasn't in the cache so we don't know about it, and we - // hold the cache lock so nothing else can be writing it in right now - - // Check if the policy already exists - raw, err := req.Storage.Get("policy/" + name) - if err != nil { - return nil, err - } - if raw == nil { - return nil, nil - } - - // Decode the policy - policy := &Policy{ - Keys: KeyEntryMap{}, - } - err = json.Unmarshal(raw.Value, policy) - if err != nil { - return nil, err - } - - persistNeeded := false - // Ensure we've moved from Key -> Keys - if policy.Key != nil && len(policy.Key) > 0 { - policy.migrateKeyToKeysMap() - persistNeeded = true - } - - // With archiving, past assumptions about the length of the keys map are no longer valid - if policy.LatestVersion == 0 && len(policy.Keys) != 0 { - policy.LatestVersion = len(policy.Keys) - persistNeeded = true - } - - // We disallow setting the version to 0, since they start at 1 since moving - // to rotate-able keys, so update if it's set to 0 - if policy.MinDecryptionVersion == 0 { - policy.MinDecryptionVersion = 1 - persistNeeded = true - } - - // On first load after an upgrade, copy keys to the archive - if policy.ArchiveVersion == 0 { - persistNeeded = true - } - - if persistNeeded { - err = policy.Persist(req.Storage) - if err != nil { - return nil, err - } - } - - lp := &lockingPolicy{ - policy: policy, - } - p.cache[name] = lp - - return lp, nil -} - -// generatePolicy is used to create a new named policy with a randomly -// generated key -func (p *policyCache) generatePolicy(storage logical.Storage, name string, derived bool) (*lockingPolicy, error) { - // Ensure one with this name doesn't already exist - lp, err := p.getPolicy(&logical.Request{ - Storage: storage, - }, name) - if err != nil { - return nil, fmt.Errorf("error checking if policy already exists: %s", err) - } - if lp != nil { - return nil, fmt.Errorf("policy %s already exists", name) - } - - p.Lock() - defer p.Unlock() - - // Now we need to check again in the cache to ensure the policy wasn't - // created since we checked getPolicy. A policy being created holds a write - // lock until it's done, so it'll be in the cache at this point. - if lp := p.cache[name]; lp != nil { - return nil, fmt.Errorf("policy %s already exists", name) - } - - // Create the policy object - policy := &Policy{ - Name: name, - CipherMode: "aes-gcm", - Derived: derived, - } - if derived { - policy.KDFMode = kdfMode - } - - err = policy.rotate(storage) - if err != nil { - return nil, err - } - - lp = &lockingPolicy{ - policy: policy, - } - p.cache[name] = lp - - // Return the policy - return lp, nil -} - -// deletePolicy deletes a policy -func (p *policyCache) deletePolicy(storage logical.Storage, name string) error { - // Ensure one with this name exists - lp, err := p.getPolicy(&logical.Request{ - Storage: storage, - }, name) - if err != nil { - return fmt.Errorf("error checking if policy already exists: %s", err) - } - if lp == nil { - return fmt.Errorf("policy %s does not exist", name) - } - - p.Lock() - defer p.Unlock() - - lp = p.cache[name] - if lp == nil { - return fmt.Errorf("policy %s not found", name) - } - - // We need to ensure all other access has stopped - lp.Lock() - defer lp.Unlock() - - // Verify this hasn't changed - if !lp.policy.DeletionAllowed { - return fmt.Errorf("deletion not allowed for policy %s", name) - } - - err = storage.Delete("policy/" + name) - if err != nil { - return fmt.Errorf("error deleting policy %s: %s", name, err) - } - - err = storage.Delete("archive/" + name) - if err != nil { - return fmt.Errorf("error deleting archive %s: %s", name, err) - } - - lp.policy = nil - delete(p.cache, name) - - return nil -} - -// lockingPolicy holds a Policy guarded by a lock -type lockingPolicy struct { - sync.RWMutex - policy *Policy -} - // KeyEntry stores the key and metadata type KeyEntry struct { Key []byte `json:"key"` @@ -427,6 +235,67 @@ func (p *Policy) Serialize() ([]byte, error) { return json.Marshal(p) } +func (p *Policy) needsUpgrade() bool { + // Ensure we've moved from Key -> Keys + if p.Key != nil && len(p.Key) > 0 { + return true + } + + // With archiving, past assumptions about the length of the keys map are no longer valid + if p.LatestVersion == 0 && len(p.Keys) != 0 { + return true + } + + // We disallow setting the version to 0, since they start at 1 since moving + // to rotate-able keys, so update if it's set to 0 + if p.MinDecryptionVersion == 0 { + return true + } + + // On first load after an upgrade, copy keys to the archive + if p.ArchiveVersion == 0 { + return true + } + + return false +} + +func (p *Policy) upgrade(storage logical.Storage) error { + persistNeeded := false + // Ensure we've moved from Key -> Keys + if p.Key != nil && len(p.Key) > 0 { + p.migrateKeyToKeysMap() + persistNeeded = true + } + + // With archiving, past assumptions about the length of the keys map are no longer valid + if p.LatestVersion == 0 && len(p.Keys) != 0 { + p.LatestVersion = len(p.Keys) + persistNeeded = true + } + + // We disallow setting the version to 0, since they start at 1 since moving + // to rotate-able keys, so update if it's set to 0 + if p.MinDecryptionVersion == 0 { + p.MinDecryptionVersion = 1 + persistNeeded = true + } + + // On first load after an upgrade, copy keys to the archive + if p.ArchiveVersion == 0 { + persistNeeded = true + } + + if persistNeeded { + err := p.Persist(storage) + if err != nil { + return err + } + } + + return nil +} + // DeriveKey is used to derive the encryption key that should // be used depending on the policy. If derivation is disabled the // raw key is used and no context is required, otherwise the KDF @@ -521,17 +390,17 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) { func (p *Policy) Decrypt(context []byte, value string) (string, error) { // Verify the prefix if !strings.HasPrefix(value, "vault:v") { - return "", certutil.UserError{Err: "invalid ciphertext"} + return "", certutil.UserError{Err: "invalid ciphertext: no prefix"} } splitVerCiphertext := strings.SplitN(strings.TrimPrefix(value, "vault:v"), ":", 2) if len(splitVerCiphertext) != 2 { - return "", certutil.UserError{Err: "invalid ciphertext"} + return "", certutil.UserError{Err: "invalid ciphertext: wrong number of fields"} } ver, err := strconv.Atoi(splitVerCiphertext[0]) if err != nil { - return "", certutil.UserError{Err: "invalid ciphertext"} + return "", certutil.UserError{Err: "invalid ciphertext: version number could not be decoded"} } if ver == 0 { @@ -540,6 +409,10 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) { ver = 1 } + if ver > p.LatestVersion { + return "", certutil.UserError{Err: "invalid ciphertext: version is too new"} + } + if p.MinDecryptionVersion > 0 && ver < p.MinDecryptionVersion { return "", certutil.UserError{Err: ErrTooOld} } @@ -560,7 +433,7 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) { // Decode the base64 decoded, err := base64.StdEncoding.DecodeString(splitVerCiphertext[1]) if err != nil { - return "", certutil.UserError{Err: "invalid ciphertext"} + return "", certutil.UserError{Err: "invalid ciphertext: could not decode base64"} } // Setup the cipher @@ -582,7 +455,7 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) { // Verify and Decrypt plain, err := gcm.Open(nil, nonce, ciphertext, nil) if err != nil { - return "", certutil.UserError{Err: "invalid ciphertext"} + return "", certutil.UserError{Err: "invalid ciphertext: unable to decrypt"} } return base64.StdEncoding.EncodeToString(plain), nil @@ -617,6 +490,8 @@ func (p *Policy) rotate(storage logical.Storage) error { p.MinDecryptionVersion = 1 } + //fmt.Printf("policy %s rotated to %d\n", p.Name, p.LatestVersion) + return p.Persist(storage) } diff --git a/builtin/logical/transit/policy_test.go b/builtin/logical/transit/policy_test.go index 04b3c3bbb1..f326416208 100644 --- a/builtin/logical/transit/policy_test.go +++ b/builtin/logical/transit/policy_test.go @@ -16,38 +16,49 @@ func resetKeysArchive() { } func Test_KeyUpgrade(t *testing.T) { + testKeyUpgradeCommon(t, newLockManager(false)) + testKeyUpgradeCommon(t, newLockManager(true)) +} + +func testKeyUpgradeCommon(t *testing.T, lm *lockManager) { storage := &logical.InmemStorage{} - policies := &policyCache{ - cache: map[string]*lockingPolicy{}, + p, lock, upserted, err := lm.GetPolicyUpsert(storage, "test", false) + if lock != nil { + defer lock.RUnlock() } - lp, err := policies.generatePolicy(storage, "test", false) if err != nil { t.Fatal(err) } - if lp == nil { + if p == nil { t.Fatal("nil policy") } + if !upserted { + t.Fatal("expected an upsert") + } - policy := lp.policy + testBytes := make([]byte, len(p.Keys[1].Key)) + copy(testBytes, p.Keys[1].Key) - testBytes := make([]byte, len(policy.Keys[1].Key)) - copy(testBytes, policy.Keys[1].Key) - - policy.Key = policy.Keys[1].Key - policy.Keys = nil - policy.migrateKeyToKeysMap() - if policy.Key != nil { + p.Key = p.Keys[1].Key + p.Keys = nil + p.migrateKeyToKeysMap() + if p.Key != nil { t.Fatal("policy.Key is not nil") } - if len(policy.Keys) != 1 { + if len(p.Keys) != 1 { t.Fatal("policy.Keys is the wrong size") } - if !reflect.DeepEqual(testBytes, policy.Keys[1].Key) { + if !reflect.DeepEqual(testBytes, p.Keys[1].Key) { t.Fatal("key mismatch") } } func Test_ArchivingUpgrade(t *testing.T) { + testArchivingUpgradeCommon(t, newLockManager(false)) + testArchivingUpgradeCommon(t, newLockManager(true)) +} + +func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) { resetKeysArchive() // First, we generate a policy and rotate it a number of times. Each time @@ -56,31 +67,27 @@ func Test_ArchivingUpgrade(t *testing.T) { // zero and latest, respectively storage := &logical.InmemStorage{} - policies := &policyCache{ - cache: map[string]*lockingPolicy{}, - } - lp, err := policies.generatePolicy(storage, "test", false) + p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false) if err != nil { t.Fatal(err) } - if lp == nil { - t.Fatal("policy is nil") + if p == nil || lock == nil { + t.Fatal("nil policy or lock") } - - policy := lp.policy + lock.RUnlock() // Store the initial key in the archive - keysArchive = append(keysArchive, policy.Keys[1]) - checkKeys(t, policy, storage, "initial", 1, 1, 1) + keysArchive = append(keysArchive, p.Keys[1]) + checkKeys(t, p, storage, "initial", 1, 1, 1) for i := 2; i <= 10; i++ { - err = policy.rotate(storage) + err = p.rotate(storage) if err != nil { t.Fatal(err) } - keysArchive = append(keysArchive, policy.Keys[i]) - checkKeys(t, policy, storage, "rotate", i, i, i) + keysArchive = append(keysArchive, p.Keys[i]) + checkKeys(t, p, storage, "rotate", i, i, i) } // Now, wipe the archive and set the archive version to zero @@ -88,44 +95,100 @@ func Test_ArchivingUpgrade(t *testing.T) { if err != nil { t.Fatal(err) } - policy.ArchiveVersion = 0 + p.ArchiveVersion = 0 // Store it, but without calling persist, so we don't trigger // handleArchiving() - buf, err := policy.Serialize() + buf, err := p.Serialize() if err != nil { t.Fatal(err) } // Write the policy into storage err = storage.Put(&logical.StorageEntry{ - Key: "policy/" + policy.Name, + Key: "policy/" + p.Name, Value: buf, }) if err != nil { t.Fatal(err) } - // Expire from the cache since we modified it under-the-hood - delete(policies.cache, "test") + // If we're caching, expire from the cache since we modified it + // under-the-hood + if lm.CacheActive() { + delete(lm.cache, "test") + } // Now get the policy again; the upgrade should happen automatically - lp, err = policies.getPolicy(&logical.Request{ - Storage: storage, - }, "test") + p, lock, err = lm.GetPolicyShared(storage, "test") if err != nil { t.Fatal(err) } - if lp == nil { - t.Fatal("policy is nil") + if p == nil || lock == nil { + t.Fatal("nil policy or lock") + } + lock.RUnlock() + + checkKeys(t, p, storage, "upgrade", 10, 10, 10) + + // Let's check some deletion logic while we're at it + + // The policy should be in there + if lm.CacheActive() && lm.cache["test"] == nil { + t.Fatal("nil policy in cache") } - policy = lp.policy + // First we'll do this wrong, by not setting the deletion flag + err = lm.DeletePolicy(storage, "test") + if err == nil { + t.Fatal("got nil error, but should not have been able to delete since we didn't set the deletion flag on the policy") + } - checkKeys(t, policy, storage, "upgrade", 10, 10, 10) + // The policy should still be in there + if lm.CacheActive() && lm.cache["test"] == nil { + t.Fatal("nil policy in cache") + } + + p, lock, err = lm.GetPolicyShared(storage, "test") + if err != nil { + t.Fatal(err) + } + if p == nil || lock == nil { + t.Fatal("policy or lock nil after bad delete") + } + lock.RUnlock() + + // Now do it properly + p.DeletionAllowed = true + err = p.Persist(storage) + if err != nil { + t.Fatal(err) + } + err = lm.DeletePolicy(storage, "test") + if err != nil { + t.Fatal(err) + } + + // The policy should *not* be in there + if lm.CacheActive() && lm.cache["test"] != nil { + t.Fatal("non-nil policy in cache") + } + + p, lock, err = lm.GetPolicyShared(storage, "test") + if err != nil { + t.Fatal(err) + } + if p != nil || lock != nil { + t.Fatal("policy or lock not nil after delete") + } } func Test_Archiving(t *testing.T) { + testArchivingCommon(t, newLockManager(false)) + testArchivingCommon(t, newLockManager(true)) +} + +func testArchivingCommon(t *testing.T, lm *lockManager) { resetKeysArchive() // First, we generate a policy and rotate it a number of times. Each time @@ -135,38 +198,35 @@ func Test_Archiving(t *testing.T) { storage := &logical.InmemStorage{} - policies := &policyCache{ - cache: map[string]*lockingPolicy{}, + p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false) + if lock != nil { + defer lock.RUnlock() } - - lp, err := policies.generatePolicy(storage, "test", false) if err != nil { t.Fatal(err) } - if lp == nil { - t.Fatal("policy is nil") + if p == nil { + t.Fatal("nil policy") } - policy := lp.policy - // Store the initial key in the archive - keysArchive = append(keysArchive, policy.Keys[1]) - checkKeys(t, policy, storage, "initial", 1, 1, 1) + keysArchive = append(keysArchive, p.Keys[1]) + checkKeys(t, p, storage, "initial", 1, 1, 1) for i := 2; i <= 10; i++ { - err = policy.rotate(storage) + err = p.rotate(storage) if err != nil { t.Fatal(err) } - keysArchive = append(keysArchive, policy.Keys[i]) - checkKeys(t, policy, storage, "rotate", i, i, i) + keysArchive = append(keysArchive, p.Keys[i]) + checkKeys(t, p, storage, "rotate", i, i, i) } // Move the min decryption version up for i := 1; i <= 10; i++ { - policy.MinDecryptionVersion = i + p.MinDecryptionVersion = i - err = policy.Persist(storage) + err = p.Persist(storage) if err != nil { t.Fatal(err) } @@ -178,14 +238,14 @@ func Test_Archiving(t *testing.T) { // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min // decryption version plus 1 (the min decryption version key // itself) - checkKeys(t, policy, storage, "minadd", 10, 10, policy.LatestVersion-policy.MinDecryptionVersion+1) + checkKeys(t, p, storage, "minadd", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1) } // Move the min decryption version down for i := 10; i >= 1; i-- { - policy.MinDecryptionVersion = i + p.MinDecryptionVersion = i - err = policy.Persist(storage) + err = p.Persist(storage) if err != nil { t.Fatal(err) } @@ -197,7 +257,7 @@ func Test_Archiving(t *testing.T) { // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min // decryption version plus 1 (the min decryption version key // itself) - checkKeys(t, policy, storage, "minsub", 10, 10, policy.LatestVersion-policy.MinDecryptionVersion+1) + checkKeys(t, p, storage, "minsub", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1) } } diff --git a/logical/system_view.go b/logical/system_view.go index 33dd01a414..d20bf0c373 100644 --- a/logical/system_view.go +++ b/logical/system_view.go @@ -26,6 +26,10 @@ type SystemView interface { // when the stored CRL will be removed during the unmounting process // anyways), we can ignore the errors to allow unmounting to complete. Tainted() bool + + // Returns true if caching is disabled. If true, no caches should be used, + // despite known slowdowns. + CachingDisabled() bool } type StaticSystemView struct { @@ -33,6 +37,7 @@ type StaticSystemView struct { MaxLeaseTTLVal time.Duration SudoPrivilegeVal bool TaintedVal bool + CachingDisabledVal bool } func (d StaticSystemView) DefaultLeaseTTL() time.Duration { @@ -50,3 +55,7 @@ func (d StaticSystemView) SudoPrivilege(path string, token string) bool { func (d StaticSystemView) Tainted() bool { return d.TaintedVal } + +func (d StaticSystemView) CachingDisabled() bool { + return d.CachingDisabledVal +} diff --git a/vault/core.go b/vault/core.go index 4c209e0665..44dcd298e2 100644 --- a/vault/core.go +++ b/vault/core.go @@ -218,6 +218,9 @@ type Core struct { maxLeaseTTL time.Duration logger *log.Logger + + // cachingDisabled indicates whether caches are disabled + cachingDisabled bool } // CoreConfig is used to parameterize a core @@ -315,6 +318,7 @@ func NewCore(conf *CoreConfig) (*Core, error) { logger: conf.Logger, defaultLeaseTTL: conf.DefaultLeaseTTL, maxLeaseTTL: conf.MaxLeaseTTL, + cachingDisabled: conf.DisableCache, } // Setup the backends diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index 9c9340ac9a..8dc806de62 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -69,3 +69,8 @@ func (d dynamicSystemView) fetchTTLs() (def, max time.Duration) { func (d dynamicSystemView) Tainted() bool { return d.mountEntry.Tainted } + +// CachingDisabled indicates whether to use caching behavior +func (d dynamicSystemView) CachingDisabled() bool { + return d.core.cachingDisabled +} diff --git a/vault/policy_store.go b/vault/policy_store.go index 8bbb79a8de..90b6e23416 100644 --- a/vault/policy_store.go +++ b/vault/policy_store.go @@ -34,12 +34,15 @@ type PolicyEntry struct { // NewPolicyStore creates a new PolicyStore that is backed // using a given view. It used used to durable store and manage named policy. -func NewPolicyStore(view *BarrierView) *PolicyStore { - cache, _ := lru.New2Q(policyCacheSize) +func NewPolicyStore(view *BarrierView, system logical.SystemView) *PolicyStore { p := &PolicyStore{ view: view, - lru: cache, } + if !system.CachingDisabled() { + cache, _ := lru.New2Q(policyCacheSize) + p.lru = cache + } + return p } @@ -50,7 +53,7 @@ func (c *Core) setupPolicyStore() error { view := c.systemBarrierView.SubView(policySubPath) // Create the policy store - c.policyStore = NewPolicyStore(view) + c.policyStore = NewPolicyStore(view, &dynamicSystemView{core: c}) // Ensure that the default policy exists, and if not, create it policy, err := c.policyStore.GetPolicy("default") @@ -95,23 +98,29 @@ func (ps *PolicyStore) SetPolicy(p *Policy) error { return fmt.Errorf("failed to persist policy: %v", err) } - // Update the LRU cache - ps.lru.Add(p.Name, p) + if ps.lru != nil { + // Update the LRU cache + ps.lru.Add(p.Name, p) + } return nil } // GetPolicy is used to fetch the named policy func (ps *PolicyStore) GetPolicy(name string) (*Policy, error) { defer metrics.MeasureSince([]string{"policy", "get_policy"}, time.Now()) - // Check for cached policy - if raw, ok := ps.lru.Get(name); ok { - return raw.(*Policy), nil + if ps.lru != nil { + // Check for cached policy + if raw, ok := ps.lru.Get(name); ok { + return raw.(*Policy), nil + } } // Special case the root policy if name == "root" { p := &Policy{Name: "root"} - ps.lru.Add(p.Name, p) + if ps.lru != nil { + ps.lru.Add(p.Name, p) + } return p, nil } @@ -152,8 +161,11 @@ func (ps *PolicyStore) GetPolicy(name string) (*Policy, error) { policy = p } - // Update the LRU cache - ps.lru.Add(name, policy) + if ps.lru != nil { + // Update the LRU cache + ps.lru.Add(name, policy) + } + return policy, nil } @@ -178,8 +190,10 @@ func (ps *PolicyStore) DeletePolicy(name string) error { return fmt.Errorf("failed to delete policy: %v", err) } - // Clear the cache - ps.lru.Remove(name) + if ps.lru != nil { + // Clear the cache + ps.lru.Remove(name) + } return nil } diff --git a/vault/policy_store_test.go b/vault/policy_store_test.go index 4e4b8fe0f9..05cbd1c79e 100644 --- a/vault/policy_store_test.go +++ b/vault/policy_store_test.go @@ -10,7 +10,16 @@ import ( func mockPolicyStore(t *testing.T) *PolicyStore { _, barrier, _ := mockBarrier(t) view := NewBarrierView(barrier, "foo/") - p := NewPolicyStore(view) + p := NewPolicyStore(view, logical.TestSystemView()) + return p +} + +func mockPolicyStoreNoCache(t *testing.T) *PolicyStore { + sysView := logical.TestSystemView() + sysView.CachingDisabledVal = true + _, barrier, _ := mockBarrier(t) + view := NewBarrierView(barrier, "foo/") + p := NewPolicyStore(view, sysView) return p } @@ -44,7 +53,13 @@ func TestPolicyStore_Root(t *testing.T) { func TestPolicyStore_CRUD(t *testing.T) { ps := mockPolicyStore(t) + testPolicyStore_CRUD(t, ps) + ps = mockPolicyStoreNoCache(t) + testPolicyStore_CRUD(t, ps) +} + +func testPolicyStore_CRUD(t *testing.T, ps *PolicyStore) { // Get should return nothing p, err := ps.GetPolicy("dev") if err != nil { diff --git a/website/source/docs/config/index.html.md b/website/source/docs/config/index.html.md index 7b7f0b2cfb..c428605558 100644 --- a/website/source/docs/config/index.html.md +++ b/website/source/docs/config/index.html.md @@ -50,9 +50,9 @@ sending a SIGHUP to the server process. These are denoted below. "tcp" is currently the only option available. A full reference for the inner syntax is below. -* `disable_cache` (optional) - A boolean. If true, this will disable the - read cache used by the physical storage subsystem. This will very - significantly impact performance. +* `disable_cache` (optional) - A boolean. If true, this will disable all caches + within Vault, including the read cache used by the physical storage + subsystem. This will very significantly impact performance. * `disable_mlock` (optional) - A boolean. If true, this will disable the server from executing the `mlock` syscall to prevent memory from being