From b18854be70cb1eb794fd1da7577410ef1baff947 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 21 Apr 2016 13:52:42 +0000 Subject: [PATCH 01/11] Plumb disabling caches through the policy store --- logical/system_view.go | 9 +++++ vault/core.go | 4 ++ vault/dynamic_system_view.go | 5 +++ vault/policy_store.go | 50 ++++++++++++++++-------- vault/policy_store_test.go | 17 +++++++- website/source/docs/config/index.html.md | 6 +-- 6 files changed, 70 insertions(+), 21 deletions(-) diff --git a/logical/system_view.go b/logical/system_view.go index 33dd01a414..4e26300cb0 100644 --- a/logical/system_view.go +++ b/logical/system_view.go @@ -26,6 +26,10 @@ type SystemView interface { // when the stored CRL will be removed during the unmounting process // anyways), we can ignore the errors to allow unmounting to complete. Tainted() bool + + // Returns true if caching is disabled. If true, no caches should be used, + // despite known slowdowns. + CacheDisabled() bool } type StaticSystemView struct { @@ -33,6 +37,7 @@ type StaticSystemView struct { MaxLeaseTTLVal time.Duration SudoPrivilegeVal bool TaintedVal bool + CacheDisabledVal bool } func (d StaticSystemView) DefaultLeaseTTL() time.Duration { @@ -50,3 +55,7 @@ func (d StaticSystemView) SudoPrivilege(path string, token string) bool { func (d StaticSystemView) Tainted() bool { return d.TaintedVal } + +func (d StaticSystemView) CacheDisabled() bool { + return d.CacheDisabledVal +} diff --git a/vault/core.go b/vault/core.go index 4c209e0665..12ceefb923 100644 --- a/vault/core.go +++ b/vault/core.go @@ -218,6 +218,9 @@ type Core struct { maxLeaseTTL time.Duration logger *log.Logger + + // cacheDisabled indicates whether caches are disabled + cacheDisabled bool } // CoreConfig is used to parameterize a core @@ -315,6 +318,7 @@ func NewCore(conf *CoreConfig) (*Core, error) { logger: conf.Logger, defaultLeaseTTL: conf.DefaultLeaseTTL, maxLeaseTTL: conf.MaxLeaseTTL, + cacheDisabled: conf.DisableCache, } // Setup the backends diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index 9c9340ac9a..b4ef6a77ae 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -69,3 +69,8 @@ func (d dynamicSystemView) fetchTTLs() (def, max time.Duration) { func (d dynamicSystemView) Tainted() bool { return d.mountEntry.Tainted } + +// CacheDisabled indicates whether to use caching behavior +func (d dynamicSystemView) CacheDisabled() bool { + return d.core.cacheDisabled +} diff --git a/vault/policy_store.go b/vault/policy_store.go index 8bbb79a8de..d4729f740b 100644 --- a/vault/policy_store.go +++ b/vault/policy_store.go @@ -22,8 +22,9 @@ const ( // PolicyStore is used to provide durable storage of policy, and to // manage ACLs associated with them. type PolicyStore struct { - view *BarrierView - lru *lru.TwoQueueCache + view *BarrierView + lru *lru.TwoQueueCache + system logical.SystemView } // PolicyEntry is used to store a policy by name @@ -34,12 +35,16 @@ type PolicyEntry struct { // NewPolicyStore creates a new PolicyStore that is backed // using a given view. It used used to durable store and manage named policy. -func NewPolicyStore(view *BarrierView) *PolicyStore { - cache, _ := lru.New2Q(policyCacheSize) +func NewPolicyStore(view *BarrierView, system logical.SystemView) *PolicyStore { p := &PolicyStore{ - view: view, - lru: cache, + view: view, + system: system, } + if !system.CacheDisabled() { + cache, _ := lru.New2Q(policyCacheSize) + p.lru = cache + } + return p } @@ -50,7 +55,7 @@ func (c *Core) setupPolicyStore() error { view := c.systemBarrierView.SubView(policySubPath) // Create the policy store - c.policyStore = NewPolicyStore(view) + c.policyStore = NewPolicyStore(view, &dynamicSystemView{core: c}) // Ensure that the default policy exists, and if not, create it policy, err := c.policyStore.GetPolicy("default") @@ -95,23 +100,29 @@ func (ps *PolicyStore) SetPolicy(p *Policy) error { return fmt.Errorf("failed to persist policy: %v", err) } - // Update the LRU cache - ps.lru.Add(p.Name, p) + if !ps.system.CacheDisabled() { + // Update the LRU cache + ps.lru.Add(p.Name, p) + } return nil } // GetPolicy is used to fetch the named policy func (ps *PolicyStore) GetPolicy(name string) (*Policy, error) { defer metrics.MeasureSince([]string{"policy", "get_policy"}, time.Now()) - // Check for cached policy - if raw, ok := ps.lru.Get(name); ok { - return raw.(*Policy), nil + if !ps.system.CacheDisabled() { + // Check for cached policy + if raw, ok := ps.lru.Get(name); ok { + return raw.(*Policy), nil + } } // Special case the root policy if name == "root" { p := &Policy{Name: "root"} - ps.lru.Add(p.Name, p) + if !ps.system.CacheDisabled() { + ps.lru.Add(p.Name, p) + } return p, nil } @@ -152,8 +163,11 @@ func (ps *PolicyStore) GetPolicy(name string) (*Policy, error) { policy = p } - // Update the LRU cache - ps.lru.Add(name, policy) + if !ps.system.CacheDisabled() { + // Update the LRU cache + ps.lru.Add(name, policy) + } + return policy, nil } @@ -178,8 +192,10 @@ func (ps *PolicyStore) DeletePolicy(name string) error { return fmt.Errorf("failed to delete policy: %v", err) } - // Clear the cache - ps.lru.Remove(name) + if !ps.system.CacheDisabled() { + // Clear the cache + ps.lru.Remove(name) + } return nil } diff --git a/vault/policy_store_test.go b/vault/policy_store_test.go index 4e4b8fe0f9..456bc8375b 100644 --- a/vault/policy_store_test.go +++ b/vault/policy_store_test.go @@ -10,7 +10,16 @@ import ( func mockPolicyStore(t *testing.T) *PolicyStore { _, barrier, _ := mockBarrier(t) view := NewBarrierView(barrier, "foo/") - p := NewPolicyStore(view) + p := NewPolicyStore(view, logical.TestSystemView()) + return p +} + +func mockPolicyStoreNoCache(t *testing.T) *PolicyStore { + sysView := logical.TestSystemView() + sysView.CacheDisabledVal = true + _, barrier, _ := mockBarrier(t) + view := NewBarrierView(barrier, "foo/") + p := NewPolicyStore(view, sysView) return p } @@ -44,7 +53,13 @@ func TestPolicyStore_Root(t *testing.T) { func TestPolicyStore_CRUD(t *testing.T) { ps := mockPolicyStore(t) + testPolicyStore_CRUD(t, ps) + ps = mockPolicyStoreNoCache(t) + testPolicyStore_CRUD(t, ps) +} + +func testPolicyStore_CRUD(t *testing.T, ps *PolicyStore) { // Get should return nothing p, err := ps.GetPolicy("dev") if err != nil { diff --git a/website/source/docs/config/index.html.md b/website/source/docs/config/index.html.md index 7b7f0b2cfb..c428605558 100644 --- a/website/source/docs/config/index.html.md +++ b/website/source/docs/config/index.html.md @@ -50,9 +50,9 @@ sending a SIGHUP to the server process. These are denoted below. "tcp" is currently the only option available. A full reference for the inner syntax is below. -* `disable_cache` (optional) - A boolean. If true, this will disable the - read cache used by the physical storage subsystem. This will very - significantly impact performance. +* `disable_cache` (optional) - A boolean. If true, this will disable all caches + within Vault, including the read cache used by the physical storage + subsystem. This will very significantly impact performance. * `disable_mlock` (optional) - A boolean. If true, this will disable the server from executing the `mlock` syscall to prevent memory from being From 32601f442429a815eccb3c1bf017d3cbfb1d089d Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 21 Apr 2016 20:32:06 +0000 Subject: [PATCH 02/11] Make a non-caching but still locking variant of transit for when caches are disabled --- builtin/logical/transit/backend.go | 12 +- builtin/logical/transit/backend_test.go | 60 +++++-- builtin/logical/transit/caching_crud.go | 109 +++++++++++++ builtin/logical/transit/path_config.go | 48 ++++-- builtin/logical/transit/path_datakey.go | 6 +- builtin/logical/transit/path_decrypt.go | 6 +- builtin/logical/transit/path_encrypt.go | 16 +- builtin/logical/transit/path_keys.go | 47 +++--- builtin/logical/transit/path_rewrap.go | 8 +- builtin/logical/transit/path_rotate.go | 36 +++- builtin/logical/transit/policy.go | 208 ++---------------------- builtin/logical/transit/policy_crud.go | 185 +++++++++++++++++++++ builtin/logical/transit/policy_test.go | 64 +++++--- builtin/logical/transit/simple_crud.go | 88 ++++++++++ logical/system_view.go | 8 +- vault/core.go | 6 +- vault/dynamic_system_view.go | 6 +- vault/policy_store.go | 12 +- vault/policy_store_test.go | 2 +- 19 files changed, 613 insertions(+), 314 deletions(-) create mode 100644 builtin/logical/transit/caching_crud.go create mode 100644 builtin/logical/transit/policy_crud.go create mode 100644 builtin/logical/transit/simple_crud.go diff --git a/builtin/logical/transit/backend.go b/builtin/logical/transit/backend.go index 3bb20d16d9..bd50d8d2ed 100644 --- a/builtin/logical/transit/backend.go +++ b/builtin/logical/transit/backend.go @@ -6,7 +6,7 @@ import ( ) func Factory(conf *logical.BackendConfig) (logical.Backend, error) { - b := Backend() + b := Backend(conf) be, err := b.Backend.Setup(conf) if err != nil { return nil, err @@ -15,7 +15,7 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) { return be, nil } -func Backend() *backend { +func Backend(conf *logical.BackendConfig) *backend { var b backend b.Backend = &framework.Backend{ Paths: []*framework.Path{ @@ -33,8 +33,10 @@ func Backend() *backend { Secrets: []*framework.Secret{}, } - b.policies = policyCache{ - cache: map[string]*lockingPolicy{}, + if conf.System.CachingDisabled() { + b.policies = newSimplePolicyCRUD() + } else { + b.policies = newCachingPolicyCRUD() } return &b @@ -42,5 +44,5 @@ func Backend() *backend { type backend struct { *framework.Backend - policies policyCache + policies policyCRUD } diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index dc0efaad41..8e385f31b3 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -559,16 +559,31 @@ func TestPolicyFuzzing(t *testing.T) { return } - be := Backend() + var be *backend + sysView := logical.TestSystemView() + be = Backend(&logical.BackendConfig{ + System: sysView, + }) + testPolicyFuzzingCommon(t, be) + + sysView.CachingDisabledVal = true + be = Backend(&logical.BackendConfig{ + System: sysView, + }) + testPolicyFuzzingCommon(t, be) +} + +func testPolicyFuzzingCommon(t *testing.T, be *backend) { storage := &logical.LockingInmemStorage{} wg := sync.WaitGroup{} funcs := []string{"encrypt", "decrypt", "rotate", "change_min_version"} + //keys := []string{"test1", "test2", "test3", "test4", "test5"} keys := []string{"test1", "test2", "test3"} // This is the goroutine loop - doFuzzy := func() { + doFuzzy := func(id int) { // Check for panics, otherwise notify we're done defer func() { if err := recover(); err != nil { @@ -587,15 +602,24 @@ func TestPolicyFuzzing(t *testing.T) { } fd := &framework.FieldData{} + var retest bool + var chosenFunc, chosenKey string + + //t.Logf("Starting") for { // Stop after 10 seconds if time.Now().Sub(startTime) > 10*time.Second { + if retest { + t.Errorf("ended runtime on a retest, id is %d", id) + } return } // Pick a function and a key - chosenFunc := funcs[rand.Int()%len(funcs)] - chosenKey := keys[rand.Int()%len(keys)] + if !retest { + chosenFunc = funcs[rand.Int()%len(funcs)] + chosenKey = keys[rand.Int()%len(keys)] + } fd.Raw = map[string]interface{}{ "name": chosenKey, @@ -605,33 +629,36 @@ func TestPolicyFuzzing(t *testing.T) { // Try to write the key to make sure it exists _, err := be.pathPolicyWrite(req, fd) if err != nil { - t.Errorf("got an error: %v", err) + t.Fatalf("got an error: %v", err) return } switch chosenFunc { // Encrypt our plaintext and store the result case "encrypt": + //t.Logf("%s, %s", chosenFunc, chosenKey) fd.Raw["plaintext"] = base64.StdEncoding.EncodeToString([]byte(testPlaintext)) fd.Schema = be.pathEncrypt().Fields resp, err := be.pathEncryptWrite(req, fd) if err != nil { - t.Errorf("got an error: %v, resp is %#v", err, *resp) + t.Fatalf("got an error: %v, resp is %#v", err, *resp) return } latestEncryptedText[chosenKey] = resp.Data["ciphertext"].(string) // Rotate to a new key version case "rotate": + //t.Errorf("%s, %s, %d", chosenFunc, chosenKey, id) fd.Schema = be.pathRotate().Fields resp, err := be.pathRotateWrite(req, fd) if err != nil { - t.Errorf("got an error: %v, resp is %#v", err, *resp) + t.Fatalf("got an error: %v, resp is %#v, chosenKey is %s", err, *resp, chosenKey) return } // Decrypt the ciphertext and compare the result case "decrypt": + //t.Logf("%s, %s", chosenFunc, chosenKey) ct := latestEncryptedText[chosenKey] if ct == "" { continue @@ -645,13 +672,14 @@ func TestPolicyFuzzing(t *testing.T) { if resp.Data["error"].(string) == ErrTooOld { continue } - t.Errorf("got an error: %v, resp is %#v, ciphertext was %s", err, *resp, latestEncryptedText[chosenKey]) - return + t.Errorf("got an error: %v, resp is %#v, ciphertext was %s, chosenKey is %s, id is %d", err, *resp, ct, chosenKey, id) + retest = true + continue } ptb64 := resp.Data["plaintext"].(string) pt, err := base64.StdEncoding.DecodeString(ptb64) if err != nil { - t.Errorf("got an error decoding base64 plaintext: %v", err) + t.Fatalf("got an error decoding base64 plaintext: %v", err) return } if string(pt) != testPlaintext { @@ -660,9 +688,10 @@ func TestPolicyFuzzing(t *testing.T) { // Change the min version, which also tests the archive functionality case "change_min_version": + //t.Logf("%s, %s", chosenFunc, chosenKey) resp, err := be.pathPolicyRead(req, fd) if err != nil { - t.Errorf("got an error reading policy %s: %v", chosenKey, err) + t.Fatalf("got an error reading policy %s: %v", chosenKey, err) return } latestVersion := resp.Data["latest_version"].(int) @@ -673,17 +702,22 @@ func TestPolicyFuzzing(t *testing.T) { fd.Schema = be.pathConfig().Fields resp, err = be.pathConfigWrite(req, fd) if err != nil { - t.Errorf("got an error setting min decryption version: %v", err) + t.Fatalf("got an error setting min decryption version: %v", err) return } } + + if retest { + t.Errorf("success, setting retest false, id is %d", id) + } + retest = false } } // Spawn 1000 of these workers for 10 seconds for i := 0; i < 1000; i++ { wg.Add(1) - go doFuzzy() + go doFuzzy(i) } // Wait for them all to finish diff --git a/builtin/logical/transit/caching_crud.go b/builtin/logical/transit/caching_crud.go new file mode 100644 index 0000000000..b0f3a9ccae --- /dev/null +++ b/builtin/logical/transit/caching_crud.go @@ -0,0 +1,109 @@ +package transit + +import ( + "sync" + + "github.com/hashicorp/vault/logical" +) + +// policyCache implements CRUD operations with a simple locking cache of +// policies +type cachingPolicyCRUD struct { + sync.RWMutex + cache map[string]lockingPolicy +} + +func newCachingPolicyCRUD() *cachingPolicyCRUD { + return &cachingPolicyCRUD{ + cache: map[string]lockingPolicy{}, + } +} + +func (p *cachingPolicyCRUD) getPolicy(storage logical.Storage, name string) (lockingPolicy, error) { + // We don't defer this since we may need to give it up and get a write lock + p.RLock() + + // First, see if we're in the cache -- if so, return that + if p.cache[name] != nil { + defer p.RUnlock() + return p.cache[name], nil + } + + // If we didn't find anything, we'll need to write into the cache, plus possibly + // persist the entry, so lock the cache + p.RUnlock() + p.Lock() + defer p.Unlock() + + // Check one more time to ensure that another process did not write during + // our lock switcheroo. + if p.cache[name] != nil { + return p.cache[name], nil + } + + return p.refreshPolicy(storage, name) +} + +func (p *cachingPolicyCRUD) refreshPolicy(storage logical.Storage, name string) (lockingPolicy, error) { + // Check once more to ensure it hasn't been added to the cache since the lock was acquired + if p.cache[name] != nil { + return p.cache[name], nil + } + + // Note that we don't need to create the locking entry until the end, + // because the policy wasn't in the cache so we don't know about it, and we + // hold the cache lock so nothing else can be writing it in right now + policy, err := fetchPolicyFromStorage(storage, name) + if err != nil { + return nil, err + } + if policy == nil { + return nil, nil + } + + lp := &mutexLockingPolicy{ + policy: policy, + mutex: &sync.RWMutex{}, + } + p.cache[name] = lp + + return lp, nil +} + +// generatePolicy is used to create a new named policy with a randomly +// generated key. The caller should hold the write lock prior to calling this. +func (p *cachingPolicyCRUD) generatePolicy(storage logical.Storage, name string, derived bool) (lockingPolicy, error) { + policy, err := generatePolicyCommon(p, storage, name, derived) + if err != nil { + return nil, err + } + + // Now we need to check again in the cache to ensure the policy wasn't + // created since we ran generatePolicy and then got the lock. A policy + // being created holds a write lock until it's done (starting from this + // point), so it'll be in the cache at this point. + if lp := p.cache[name]; lp != nil { + return lp, nil + } + + lp := &mutexLockingPolicy{ + policy: policy, + mutex: &sync.RWMutex{}, + } + p.cache[name] = lp + + // Return the policy + return lp, nil +} + +// deletePolicy deletes a policy +func (p *cachingPolicyCRUD) deletePolicy(storage logical.Storage, lp lockingPolicy, name string) error { + err := deletePolicyCommon(p, lp, storage, name) + if err != nil { + return err + } + + delete(p.cache, name) + + return nil +} diff --git a/builtin/logical/transit/path_config.go b/builtin/logical/transit/path_config.go index 967b743412..cf4348bd57 100644 --- a/builtin/logical/transit/path_config.go +++ b/builtin/logical/transit/path_config.go @@ -42,26 +42,48 @@ func (b *backend) pathConfigWrite( name := d.Get("name").(string) // Check if the policy already exists - lp, err := b.policies.getPolicy(req, name) + lp, err := b.policies.getPolicy(req.Storage, name) if err != nil { return nil, err } if lp == nil { return logical.ErrorResponse( - fmt.Sprintf("no existing role named %s could be found", name)), + fmt.Sprintf("no existing key named %s could be found", name)), logical.ErrInvalidRequest } + currDeletionAllowed := lp.Policy().DeletionAllowed + currMinDecryptionVersion := lp.Policy().MinDecryptionVersion + + // Hold both locks since we want to ensure the policy doesn't change from underneath us + b.policies.Lock() + defer b.policies.Unlock() lp.Lock() defer lp.Unlock() + // Refresh in case it's changed since before we grabbed the lock + lp, err = b.policies.refreshPolicy(req.Storage, name) + if err != nil { + return nil, err + } + if lp == nil { + return nil, fmt.Errorf("error finding key %s after locking for changes", name) + } + // Verify if wasn't deleted before we grabbed the lock - if lp.policy == nil { - return nil, fmt.Errorf("no existing role named %s could be found", name) + if lp.Policy() == nil { + return nil, fmt.Errorf("no existing key named %s could be found", name) } resp := &logical.Response{} + // Check for anything to have been updated since we got the policy + if currDeletionAllowed != lp.Policy().DeletionAllowed || + currMinDecryptionVersion != lp.Policy().MinDecryptionVersion { + resp.AddWarning("key configuration has changed since this endpoint was called, not updating") + return resp, nil + } + persistNeeded := false minDecryptionVersionRaw, ok := d.GetOk("min_decryption_version") @@ -78,12 +100,12 @@ func (b *backend) pathConfigWrite( } if minDecryptionVersion > 0 && - minDecryptionVersion != lp.policy.MinDecryptionVersion { - if minDecryptionVersion > lp.policy.LatestVersion { + minDecryptionVersion != lp.Policy().MinDecryptionVersion { + if minDecryptionVersion > lp.Policy().LatestVersion { return logical.ErrorResponse( - fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, lp.policy.LatestVersion)), nil + fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, lp.Policy().LatestVersion)), nil } - lp.policy.MinDecryptionVersion = minDecryptionVersion + lp.Policy().MinDecryptionVersion = minDecryptionVersion persistNeeded = true } } @@ -91,8 +113,8 @@ func (b *backend) pathConfigWrite( allowDeletionInt, ok := d.GetOk("deletion_allowed") if ok { allowDeletion := allowDeletionInt.(bool) - if allowDeletion != lp.policy.DeletionAllowed { - lp.policy.DeletionAllowed = allowDeletion + if allowDeletion != lp.Policy().DeletionAllowed { + lp.Policy().DeletionAllowed = allowDeletion persistNeeded = true } } @@ -100,8 +122,8 @@ func (b *backend) pathConfigWrite( // Add this as a guard here before persisting since we now require the min // decryption version to start at 1; even if it's not explicitly set here, // force the upgrade - if lp.policy.MinDecryptionVersion == 0 { - lp.policy.MinDecryptionVersion = 1 + if lp.Policy().MinDecryptionVersion == 0 { + lp.Policy().MinDecryptionVersion = 1 persistNeeded = true } @@ -109,7 +131,7 @@ func (b *backend) pathConfigWrite( return nil, nil } - return resp, lp.policy.Persist(req.Storage) + return resp, lp.Policy().Persist(req.Storage) } const pathConfigHelpSyn = `Configure a named encryption key` diff --git a/builtin/logical/transit/path_datakey.go b/builtin/logical/transit/path_datakey.go index f2bec71d09..9e64836730 100644 --- a/builtin/logical/transit/path_datakey.go +++ b/builtin/logical/transit/path_datakey.go @@ -73,7 +73,7 @@ func (b *backend) pathDatakeyWrite( } // Get the policy - lp, err := b.policies.getPolicy(req, name) + lp, err := b.policies.getPolicy(req.Storage, name) if err != nil { return nil, err } @@ -87,7 +87,7 @@ func (b *backend) pathDatakeyWrite( defer lp.RUnlock() // Verify if wasn't deleted before we grabbed the lock - if lp.policy == nil { + if lp.Policy() == nil { return nil, fmt.Errorf("no existing policy named %s could be found", name) } @@ -107,7 +107,7 @@ func (b *backend) pathDatakeyWrite( return nil, err } - ciphertext, err := lp.policy.Encrypt(context, base64.StdEncoding.EncodeToString(newKey)) + ciphertext, err := lp.Policy().Encrypt(context, base64.StdEncoding.EncodeToString(newKey)) if err != nil { switch err.(type) { case certutil.UserError: diff --git a/builtin/logical/transit/path_decrypt.go b/builtin/logical/transit/path_decrypt.go index 66191b6295..a58709858a 100644 --- a/builtin/logical/transit/path_decrypt.go +++ b/builtin/logical/transit/path_decrypt.go @@ -58,7 +58,7 @@ func (b *backend) pathDecryptWrite( } // Get the policy - lp, err := b.policies.getPolicy(req, name) + lp, err := b.policies.getPolicy(req.Storage, name) if err != nil { return nil, err } @@ -72,11 +72,11 @@ func (b *backend) pathDecryptWrite( defer lp.RUnlock() // Verify if wasn't deleted before we grabbed the lock - if lp.policy == nil { + if lp.Policy() == nil { return nil, fmt.Errorf("no existing policy named %s could be found", name) } - plaintext, err := lp.policy.Decrypt(context, ciphertext) + plaintext, err := lp.Policy().Decrypt(context, ciphertext) if err != nil { switch err.(type) { case certutil.UserError: diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go index fc2e2048f3..8c4cc91e7b 100644 --- a/builtin/logical/transit/path_encrypt.go +++ b/builtin/logical/transit/path_encrypt.go @@ -44,7 +44,7 @@ func (b *backend) pathEncrypt() *framework.Path { func (b *backend) pathEncryptExistenceCheck( req *logical.Request, d *framework.FieldData) (bool, error) { name := d.Get("name").(string) - lp, err := b.policies.getPolicy(req, name) + lp, err := b.policies.getPolicy(req.Storage, name) if err != nil { return false, err } @@ -72,37 +72,43 @@ func (b *backend) pathEncryptWrite( } // Get the policy - lp, err := b.policies.getPolicy(req, name) + lp, err := b.policies.getPolicy(req.Storage, name) if err != nil { return nil, err } - // Error if invalid policy + // Error or upsert if invalid policy if lp == nil { if req.Operation != logical.CreateOperation { return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } + // Get a write lock + b.policies.Lock() + isDerived := len(context) != 0 + // This also checks to make sure one hasn't been created since we grabbed the write lock lp, err = b.policies.generatePolicy(req.Storage, name, isDerived) // If the error is that the policy has been created in the interim we // will get the policy back, so only consider it an error if err is not // nil and we do not get a policy back if err != nil && lp != nil { + b.policies.Unlock() return nil, err } + b.policies.Unlock() } lp.RLock() defer lp.RUnlock() // Verify if wasn't deleted before we grabbed the lock - if lp.policy == nil { + if lp.Policy() == nil { return nil, fmt.Errorf("no existing policy named %s could be found", name) } - ciphertext, err := lp.policy.Encrypt(context, value) + ciphertext, err := lp.Policy().Encrypt(context, value) if err != nil { switch err.(type) { case certutil.UserError: diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index 32587b6c60..02d7f882b3 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -36,20 +36,14 @@ func (b *backend) pathKeys() *framework.Path { func (b *backend) pathPolicyWrite( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + b.policies.Lock() + defer b.policies.Unlock() + name := d.Get("name").(string) derived := d.Get("derived").(bool) - // Check if the policy already exists - existing, err := b.policies.getPolicy(req, name) - if err != nil { - return nil, err - } - if existing != nil { - return nil, nil - } - - // Generate the policy - _, err = b.policies.generatePolicy(req.Storage, name, derived) + // Generate the policy; this will also check if it exists for safety + _, err := b.policies.generatePolicy(req.Storage, name, derived) return nil, err } @@ -57,7 +51,7 @@ func (b *backend) pathPolicyRead( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - lp, err := b.policies.getPolicy(req, name) + lp, err := b.policies.getPolicy(req.Storage, name) if err != nil { return nil, err } @@ -69,27 +63,27 @@ func (b *backend) pathPolicyRead( defer lp.RUnlock() // Verify if wasn't deleted before we grabbed the lock - if lp.policy == nil { + if lp.Policy() == nil { return nil, fmt.Errorf("no existing policy named %s could be found", name) } // Return the response resp := &logical.Response{ Data: map[string]interface{}{ - "name": lp.policy.Name, - "cipher_mode": lp.policy.CipherMode, - "derived": lp.policy.Derived, - "deletion_allowed": lp.policy.DeletionAllowed, - "min_decryption_version": lp.policy.MinDecryptionVersion, - "latest_version": lp.policy.LatestVersion, + "name": lp.Policy().Name, + "cipher_mode": lp.Policy().CipherMode, + "derived": lp.Policy().Derived, + "deletion_allowed": lp.Policy().DeletionAllowed, + "min_decryption_version": lp.Policy().MinDecryptionVersion, + "latest_version": lp.Policy().LatestVersion, }, } - if lp.policy.Derived { - resp.Data["kdf_mode"] = lp.policy.KDFMode + if lp.Policy().Derived { + resp.Data["kdf_mode"] = lp.Policy().KDFMode } retKeys := map[string]int64{} - for k, v := range lp.policy.Keys { + for k, v := range lp.Policy().Keys { retKeys[strconv.Itoa(k)] = v.CreationTime } resp.Data["keys"] = retKeys @@ -101,7 +95,7 @@ func (b *backend) pathPolicyDelete( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - lp, err := b.policies.getPolicy(req, name) + lp, err := b.policies.getPolicy(req.Storage, name) if err != nil { return logical.ErrorResponse(fmt.Sprintf("error looking up policy %s, error is %s", name, err)), err } @@ -109,7 +103,12 @@ func (b *backend) pathPolicyDelete( return logical.ErrorResponse(fmt.Sprintf("no such key %s", name)), logical.ErrInvalidRequest } - err = b.policies.deletePolicy(req.Storage, name) + b.policies.Lock() + defer b.policies.Unlock() + lp.Lock() + defer lp.Unlock() + + err = b.policies.deletePolicy(req.Storage, lp, name) if err != nil { return logical.ErrorResponse(fmt.Sprintf("error deleting policy %s: %s", name, err)), err } diff --git a/builtin/logical/transit/path_rewrap.go b/builtin/logical/transit/path_rewrap.go index d8d01bddca..2233689b3b 100644 --- a/builtin/logical/transit/path_rewrap.go +++ b/builtin/logical/transit/path_rewrap.go @@ -59,7 +59,7 @@ func (b *backend) pathRewrapWrite( } // Get the policy - lp, err := b.policies.getPolicy(req, name) + lp, err := b.policies.getPolicy(req.Storage, name) if err != nil { return nil, err } @@ -73,11 +73,11 @@ func (b *backend) pathRewrapWrite( defer lp.RUnlock() // Verify if wasn't deleted before we grabbed the lock - if lp.policy == nil { + if lp.Policy() == nil { return nil, fmt.Errorf("no existing policy named %s could be found", name) } - plaintext, err := lp.policy.Decrypt(context, value) + plaintext, err := lp.Policy().Decrypt(context, value) if err != nil { switch err.(type) { case certutil.UserError: @@ -93,7 +93,7 @@ func (b *backend) pathRewrapWrite( return nil, fmt.Errorf("empty plaintext returned during rewrap") } - ciphertext, err := lp.policy.Encrypt(context, plaintext) + ciphertext, err := lp.Policy().Encrypt(context, plaintext) if err != nil { switch err.(type) { case certutil.UserError: diff --git a/builtin/logical/transit/path_rotate.go b/builtin/logical/transit/path_rotate.go index 90bbc2e185..dbde58e289 100644 --- a/builtin/logical/transit/path_rotate.go +++ b/builtin/logical/transit/path_rotate.go @@ -31,26 +31,48 @@ func (b *backend) pathRotateWrite( name := d.Get("name").(string) // Get the policy - lp, err := b.policies.getPolicy(req, name) + lp, err := b.policies.getPolicy(req.Storage, name) if err != nil { return nil, err } // Error if invalid policy if lp == nil { - return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest + return logical.ErrorResponse("key not found"), logical.ErrInvalidRequest } + keyVersion := lp.Policy().LatestVersion + + // lock the policies object so we can refresh + b.policies.Lock() + defer b.policies.Unlock() lp.Lock() defer lp.Unlock() - // Verify if wasn't deleted before we grabbed the lock - if lp.policy == nil { - return nil, fmt.Errorf("no existing policy named %s could be found", name) + // Refresh in case it's changed since before we grabbed the lock + lp, err = b.policies.refreshPolicy(req.Storage, name) + if err != nil { + return nil, err + } + if lp == nil { + return nil, fmt.Errorf("error finding key %s after locking for changes", name) } - // Generate the policy - err = lp.policy.rotate(req.Storage) + // Verify if wasn't deleted before we grabbed the lock + if lp.Policy() == nil { + return nil, fmt.Errorf("no existing key named %s could be found", name) + } + + // Make sure that the policy hasn't been rotated simultaneously + if keyVersion != lp.Policy().LatestVersion { + resp := &logical.Response{} + resp.AddWarning("key has been rotated since this endpoint was called; did not perform rotation") + return resp, nil + } + + //fmt.Printf("Rotating key %s, orig seen version is %d, currVersion is %d\n", name, keyVersion, lp.Policy().LatestVersion) + // Rotate the policy + err = lp.Policy().rotate(req.Storage) return nil, err } diff --git a/builtin/logical/transit/policy.go b/builtin/logical/transit/policy.go index 17d7cc00b3..ad47af7547 100644 --- a/builtin/logical/transit/policy.go +++ b/builtin/logical/transit/policy.go @@ -9,7 +9,6 @@ import ( "fmt" "strconv" "strings" - "sync" "time" "github.com/hashicorp/vault/helper/certutil" @@ -24,197 +23,6 @@ const ( ErrTooOld = "ciphertext version is disallowed by policy (too old)" ) -// policyCache implements a simple locking cache of policies -type policyCache struct { - sync.RWMutex - cache map[string]*lockingPolicy -} - -// getPolicy loads a policy into the cache or returns one already in the cache -func (p *policyCache) getPolicy(req *logical.Request, name string) (*lockingPolicy, error) { - // We don't defer this since we may need to give it up and get a write lock - p.RLock() - - // First, see if we're in the cache -- if so, return that - if p.cache[name] != nil { - defer p.RUnlock() - return p.cache[name], nil - } - - // If we didn't find anything, we'll need to write into the cache, plus possibly - // persist the entry, so lock the cache - p.RUnlock() - p.Lock() - defer p.Unlock() - - // Check one more time to ensure that another process did not write during - // our lock switcheroo. - if p.cache[name] != nil { - return p.cache[name], nil - } - - // Note that we don't need to create the locking entry until the end, - // because the policy wasn't in the cache so we don't know about it, and we - // hold the cache lock so nothing else can be writing it in right now - - // Check if the policy already exists - raw, err := req.Storage.Get("policy/" + name) - if err != nil { - return nil, err - } - if raw == nil { - return nil, nil - } - - // Decode the policy - policy := &Policy{ - Keys: KeyEntryMap{}, - } - err = json.Unmarshal(raw.Value, policy) - if err != nil { - return nil, err - } - - persistNeeded := false - // Ensure we've moved from Key -> Keys - if policy.Key != nil && len(policy.Key) > 0 { - policy.migrateKeyToKeysMap() - persistNeeded = true - } - - // With archiving, past assumptions about the length of the keys map are no longer valid - if policy.LatestVersion == 0 && len(policy.Keys) != 0 { - policy.LatestVersion = len(policy.Keys) - persistNeeded = true - } - - // We disallow setting the version to 0, since they start at 1 since moving - // to rotate-able keys, so update if it's set to 0 - if policy.MinDecryptionVersion == 0 { - policy.MinDecryptionVersion = 1 - persistNeeded = true - } - - // On first load after an upgrade, copy keys to the archive - if policy.ArchiveVersion == 0 { - persistNeeded = true - } - - if persistNeeded { - err = policy.Persist(req.Storage) - if err != nil { - return nil, err - } - } - - lp := &lockingPolicy{ - policy: policy, - } - p.cache[name] = lp - - return lp, nil -} - -// generatePolicy is used to create a new named policy with a randomly -// generated key -func (p *policyCache) generatePolicy(storage logical.Storage, name string, derived bool) (*lockingPolicy, error) { - // Ensure one with this name doesn't already exist - lp, err := p.getPolicy(&logical.Request{ - Storage: storage, - }, name) - if err != nil { - return nil, fmt.Errorf("error checking if policy already exists: %s", err) - } - if lp != nil { - return nil, fmt.Errorf("policy %s already exists", name) - } - - p.Lock() - defer p.Unlock() - - // Now we need to check again in the cache to ensure the policy wasn't - // created since we checked getPolicy. A policy being created holds a write - // lock until it's done, so it'll be in the cache at this point. - if lp := p.cache[name]; lp != nil { - return nil, fmt.Errorf("policy %s already exists", name) - } - - // Create the policy object - policy := &Policy{ - Name: name, - CipherMode: "aes-gcm", - Derived: derived, - } - if derived { - policy.KDFMode = kdfMode - } - - err = policy.rotate(storage) - if err != nil { - return nil, err - } - - lp = &lockingPolicy{ - policy: policy, - } - p.cache[name] = lp - - // Return the policy - return lp, nil -} - -// deletePolicy deletes a policy -func (p *policyCache) deletePolicy(storage logical.Storage, name string) error { - // Ensure one with this name exists - lp, err := p.getPolicy(&logical.Request{ - Storage: storage, - }, name) - if err != nil { - return fmt.Errorf("error checking if policy already exists: %s", err) - } - if lp == nil { - return fmt.Errorf("policy %s does not exist", name) - } - - p.Lock() - defer p.Unlock() - - lp = p.cache[name] - if lp == nil { - return fmt.Errorf("policy %s not found", name) - } - - // We need to ensure all other access has stopped - lp.Lock() - defer lp.Unlock() - - // Verify this hasn't changed - if !lp.policy.DeletionAllowed { - return fmt.Errorf("deletion not allowed for policy %s", name) - } - - err = storage.Delete("policy/" + name) - if err != nil { - return fmt.Errorf("error deleting policy %s: %s", name, err) - } - - err = storage.Delete("archive/" + name) - if err != nil { - return fmt.Errorf("error deleting archive %s: %s", name, err) - } - - lp.policy = nil - delete(p.cache, name) - - return nil -} - -// lockingPolicy holds a Policy guarded by a lock -type lockingPolicy struct { - sync.RWMutex - policy *Policy -} - // KeyEntry stores the key and metadata type KeyEntry struct { Key []byte `json:"key"` @@ -521,17 +329,17 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) { func (p *Policy) Decrypt(context []byte, value string) (string, error) { // Verify the prefix if !strings.HasPrefix(value, "vault:v") { - return "", certutil.UserError{Err: "invalid ciphertext"} + return "", certutil.UserError{Err: "invalid ciphertext: no prefix"} } splitVerCiphertext := strings.SplitN(strings.TrimPrefix(value, "vault:v"), ":", 2) if len(splitVerCiphertext) != 2 { - return "", certutil.UserError{Err: "invalid ciphertext"} + return "", certutil.UserError{Err: "invalid ciphertext: wrong number of fields"} } ver, err := strconv.Atoi(splitVerCiphertext[0]) if err != nil { - return "", certutil.UserError{Err: "invalid ciphertext"} + return "", certutil.UserError{Err: "invalid ciphertext: version number could not be decoded"} } if ver == 0 { @@ -540,6 +348,10 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) { ver = 1 } + if ver > p.LatestVersion { + return "", certutil.UserError{Err: "invalid ciphertext: version is too new"} + } + if p.MinDecryptionVersion > 0 && ver < p.MinDecryptionVersion { return "", certutil.UserError{Err: ErrTooOld} } @@ -560,7 +372,7 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) { // Decode the base64 decoded, err := base64.StdEncoding.DecodeString(splitVerCiphertext[1]) if err != nil { - return "", certutil.UserError{Err: "invalid ciphertext"} + return "", certutil.UserError{Err: "invalid ciphertext: could not decode base64"} } // Setup the cipher @@ -582,7 +394,7 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) { // Verify and Decrypt plain, err := gcm.Open(nil, nonce, ciphertext, nil) if err != nil { - return "", certutil.UserError{Err: "invalid ciphertext"} + return "", certutil.UserError{Err: "invalid ciphertext: unable to decrypt"} } return base64.StdEncoding.EncodeToString(plain), nil @@ -617,6 +429,8 @@ func (p *Policy) rotate(storage logical.Storage) error { p.MinDecryptionVersion = 1 } + //fmt.Printf("policy %s rotated to %d\n", p.Name, p.LatestVersion) + return p.Persist(storage) } diff --git a/builtin/logical/transit/policy_crud.go b/builtin/logical/transit/policy_crud.go new file mode 100644 index 0000000000..90c68c90fb --- /dev/null +++ b/builtin/logical/transit/policy_crud.go @@ -0,0 +1,185 @@ +package transit + +import ( + "encoding/json" + "fmt" + "sync" + + "github.com/hashicorp/vault/logical" +) + +type lockingPolicy interface { + Lock() + RLock() + Unlock() + RUnlock() + Policy() *Policy + SetPolicy(*Policy) +} + +type policyCRUD interface { + // getPolicy returns a lockingPolicy. It performs its own locking according + // to implementation. + getPolicy(storage logical.Storage, name string) (lockingPolicy, error) + + // refreshPolicy returns a lockingPolicy. It does not perform its own + // locking; a write lock must be held before calling. + refreshPolicy(storage logical.Storage, name string) (lockingPolicy, error) + + // generatePolicy generates and returns a lockingPolicy. A write lock must + // be held before calling. + generatePolicy(storage logical.Storage, name string, derived bool) (lockingPolicy, error) + + // deletePolicy deletes a lockingPolicy. A write lock must be held on both + // the CRUD implementation and the lockingPolicy before calling. + deletePolicy(storage logical.Storage, lp lockingPolicy, name string) error + + // These are generally satisfied by embedded mutexes in the implementing struct + Lock() + RLock() + Unlock() + RUnlock() +} + +type mutexLockingPolicy struct { + mutex *sync.RWMutex + policy *Policy +} + +func (m *mutexLockingPolicy) Lock() { + m.mutex.Lock() +} + +func (m *mutexLockingPolicy) RLock() { + m.mutex.RLock() +} + +func (m *mutexLockingPolicy) Unlock() { + m.mutex.Unlock() +} + +func (m *mutexLockingPolicy) RUnlock() { + m.mutex.RUnlock() +} + +func (m *mutexLockingPolicy) Policy() *Policy { + return m.policy +} + +func (m *mutexLockingPolicy) SetPolicy(p *Policy) { + m.policy = p +} + +// The caller should hold the write lock when calling this +func fetchPolicyFromStorage(storage logical.Storage, name string) (*Policy, error) { + // Check if the policy already exists + raw, err := storage.Get("policy/" + name) + if err != nil { + return nil, err + } + if raw == nil { + return nil, nil + } + + // Decode the policy + policy := &Policy{ + Keys: KeyEntryMap{}, + } + err = json.Unmarshal(raw.Value, policy) + if err != nil { + return nil, err + } + + persistNeeded := false + // Ensure we've moved from Key -> Keys + if policy.Key != nil && len(policy.Key) > 0 { + policy.migrateKeyToKeysMap() + persistNeeded = true + } + + // With archiving, past assumptions about the length of the keys map are no longer valid + if policy.LatestVersion == 0 && len(policy.Keys) != 0 { + policy.LatestVersion = len(policy.Keys) + persistNeeded = true + } + + // We disallow setting the version to 0, since they start at 1 since moving + // to rotate-able keys, so update if it's set to 0 + if policy.MinDecryptionVersion == 0 { + policy.MinDecryptionVersion = 1 + persistNeeded = true + } + + // On first load after an upgrade, copy keys to the archive + if policy.ArchiveVersion == 0 { + persistNeeded = true + } + + if persistNeeded { + err = policy.Persist(storage) + if err != nil { + return nil, err + } + } + + return policy, nil +} + +// generatePolicyCommon is used to create a new named policy with a randomly +// generated key. The caller should have a write lock prior to calling this. +func generatePolicyCommon(p policyCRUD, storage logical.Storage, name string, derived bool) (*Policy, error) { + // Make sure this doesn't exist in case it was created before we got the write lock + policy, err := fetchPolicyFromStorage(storage, name) + if err != nil { + return nil, err + } + if policy != nil { + return policy, nil + } + + //log.Printf("generating a new policy with name %s", name) + + // Create the policy object + policy = &Policy{ + Name: name, + CipherMode: "aes-gcm", + Derived: derived, + } + if derived { + policy.KDFMode = kdfMode + } + + err = policy.rotate(storage) + if err != nil { + return nil, err + } + + return policy, err +} + +// deletePolicy deletes a policy. The caller should hold the write lock for both the policy and lockingPolicy prior to calling this. +func deletePolicyCommon(p policyCRUD, lp lockingPolicy, storage logical.Storage, name string) error { + if lp.Policy() == nil { + // This got deleted before we grabbed the lock + return fmt.Errorf("policy already deleted") + } + + // Verify this hasn't changed + if !lp.Policy().DeletionAllowed { + return fmt.Errorf("deletion not allowed for policy %s", name) + } + + err := storage.Delete("policy/" + name) + if err != nil { + return fmt.Errorf("error deleting policy %s: %s", name, err) + } + + err = storage.Delete("archive/" + name) + if err != nil { + return fmt.Errorf("error deleting archive %s: %s", name, err) + } + + lp.SetPolicy(nil) + + return nil +} diff --git a/builtin/logical/transit/policy_test.go b/builtin/logical/transit/policy_test.go index 04b3c3bbb1..af49c27d2c 100644 --- a/builtin/logical/transit/policy_test.go +++ b/builtin/logical/transit/policy_test.go @@ -16,19 +16,24 @@ func resetKeysArchive() { } func Test_KeyUpgrade(t *testing.T) { + testKeyUpgradeCommon(t, newSimplePolicyCRUD()) + testKeyUpgradeCommon(t, newCachingPolicyCRUD()) +} + +func testKeyUpgradeCommon(t *testing.T, policies policyCRUD) { storage := &logical.InmemStorage{} - policies := &policyCache{ - cache: map[string]*lockingPolicy{}, - } lp, err := policies.generatePolicy(storage, "test", false) if err != nil { t.Fatal(err) } if lp == nil { - t.Fatal("nil policy") + t.Fatal("nil lockingPolicy") } - policy := lp.policy + policy := lp.Policy() + if policy == nil { + t.Fatal("nil policy in lockingPolicy") + } testBytes := make([]byte, len(policy.Keys[1].Key)) copy(testBytes, policy.Keys[1].Key) @@ -48,6 +53,11 @@ func Test_KeyUpgrade(t *testing.T) { } func Test_ArchivingUpgrade(t *testing.T) { + testArchivingUpgradeCommon(t, newSimplePolicyCRUD()) + testArchivingUpgradeCommon(t, newCachingPolicyCRUD()) +} + +func testArchivingUpgradeCommon(t *testing.T, policies policyCRUD) { resetKeysArchive() // First, we generate a policy and rotate it a number of times. Each time @@ -56,19 +66,19 @@ func Test_ArchivingUpgrade(t *testing.T) { // zero and latest, respectively storage := &logical.InmemStorage{} - policies := &policyCache{ - cache: map[string]*lockingPolicy{}, - } lp, err := policies.generatePolicy(storage, "test", false) if err != nil { t.Fatal(err) } if lp == nil { - t.Fatal("policy is nil") + t.Fatal("nil lockingPolicy") } - policy := lp.policy + policy := lp.Policy() + if policy == nil { + t.Fatal("nil policy in lockingPolicy") + } // Store the initial key in the archive keysArchive = append(keysArchive, policy.Keys[1]) @@ -106,26 +116,35 @@ func Test_ArchivingUpgrade(t *testing.T) { t.Fatal(err) } - // Expire from the cache since we modified it under-the-hood - delete(policies.cache, "test") + // If it's a caching CRUD, expire from the cache since we modified it + // under-the-hood + if cachingCRUD, ok := policies.(*cachingPolicyCRUD); ok { + delete(cachingCRUD.cache, "test") + } // Now get the policy again; the upgrade should happen automatically - lp, err = policies.getPolicy(&logical.Request{ - Storage: storage, - }, "test") + lp, err = policies.getPolicy(storage, "test") if err != nil { t.Fatal(err) } if lp == nil { - t.Fatal("policy is nil") + t.Fatal("nil lockingPolicy") } - policy = lp.policy + policy = lp.Policy() + if policy == nil { + t.Fatal("nil policy in lockingPolicy") + } checkKeys(t, policy, storage, "upgrade", 10, 10, 10) } func Test_Archiving(t *testing.T) { + testArchivingCommon(t, newSimplePolicyCRUD()) + testArchivingCommon(t, newCachingPolicyCRUD()) +} + +func testArchivingCommon(t *testing.T, policies policyCRUD) { resetKeysArchive() // First, we generate a policy and rotate it a number of times. Each time @@ -135,19 +154,18 @@ func Test_Archiving(t *testing.T) { storage := &logical.InmemStorage{} - policies := &policyCache{ - cache: map[string]*lockingPolicy{}, - } - lp, err := policies.generatePolicy(storage, "test", false) if err != nil { t.Fatal(err) } if lp == nil { - t.Fatal("policy is nil") + t.Fatal("nil lockingPolicy") } - policy := lp.policy + policy := lp.Policy() + if policy == nil { + t.Fatal("nil policy in lockingPolicy") + } // Store the initial key in the archive keysArchive = append(keysArchive, policy.Keys[1]) diff --git a/builtin/logical/transit/simple_crud.go b/builtin/logical/transit/simple_crud.go new file mode 100644 index 0000000000..874f760d78 --- /dev/null +++ b/builtin/logical/transit/simple_crud.go @@ -0,0 +1,88 @@ +package transit + +import ( + "sync" + + "github.com/hashicorp/vault/logical" +) + +// Directly implements CRUD operations without caching, mapped to the backend, +// but implements locking to ensure that we can't overwrite data on the backend +// from multiple operators +type simplePolicyCRUD struct { + sync.RWMutex + locks map[string]*sync.RWMutex + locksMapMutex sync.RWMutex +} + +func newSimplePolicyCRUD() *simplePolicyCRUD { + return &simplePolicyCRUD{ + locks: map[string]*sync.RWMutex{}, + } +} + +func (p *simplePolicyCRUD) ensureLockExists(name string) { + p.locksMapMutex.RLock() + + if p.locks[name] == nil { + p.locksMapMutex.RUnlock() + p.locksMapMutex.Lock() + // Make sure nothing has appeared since we switched the lock type + if p.locks[name] == nil { + p.locks[name] = &sync.RWMutex{} + } + p.locksMapMutex.Unlock() + return + } + + p.locksMapMutex.RUnlock() +} + +func (p *simplePolicyCRUD) getPolicy(storage logical.Storage, name string) (lockingPolicy, error) { + // Use a write lock since fetching the policy can cause a need for upgrade persistence + p.Lock() + defer p.Unlock() + + return p.refreshPolicy(storage, name) +} + +func (p *simplePolicyCRUD) refreshPolicy(storage logical.Storage, name string) (lockingPolicy, error) { + p.ensureLockExists(name) + + policy, err := fetchPolicyFromStorage(storage, name) + if err != nil { + return nil, err + } + if policy == nil { + return nil, nil + } + + lp := &mutexLockingPolicy{ + policy: policy, + mutex: p.locks[name], + } + + return lp, nil +} + +// The caller must hold the write lock when calling this +func (p *simplePolicyCRUD) generatePolicy(storage logical.Storage, name string, derived bool) (lockingPolicy, error) { + p.ensureLockExists(name) + + policy, err := generatePolicyCommon(p, storage, name, derived) + if err != nil { + return nil, err + } + + lp := &mutexLockingPolicy{ + policy: policy, + mutex: p.locks[name], + } + + return lp, nil +} + +// The caller must hold the write lock when calling this +func (p *simplePolicyCRUD) deletePolicy(storage logical.Storage, lp lockingPolicy, name string) error { + return deletePolicyCommon(p, lp, storage, name) +} diff --git a/logical/system_view.go b/logical/system_view.go index 4e26300cb0..d20bf0c373 100644 --- a/logical/system_view.go +++ b/logical/system_view.go @@ -29,7 +29,7 @@ type SystemView interface { // Returns true if caching is disabled. If true, no caches should be used, // despite known slowdowns. - CacheDisabled() bool + CachingDisabled() bool } type StaticSystemView struct { @@ -37,7 +37,7 @@ type StaticSystemView struct { MaxLeaseTTLVal time.Duration SudoPrivilegeVal bool TaintedVal bool - CacheDisabledVal bool + CachingDisabledVal bool } func (d StaticSystemView) DefaultLeaseTTL() time.Duration { @@ -56,6 +56,6 @@ func (d StaticSystemView) Tainted() bool { return d.TaintedVal } -func (d StaticSystemView) CacheDisabled() bool { - return d.CacheDisabledVal +func (d StaticSystemView) CachingDisabled() bool { + return d.CachingDisabledVal } diff --git a/vault/core.go b/vault/core.go index 12ceefb923..44dcd298e2 100644 --- a/vault/core.go +++ b/vault/core.go @@ -219,8 +219,8 @@ type Core struct { logger *log.Logger - // cacheDisabled indicates whether caches are disabled - cacheDisabled bool + // cachingDisabled indicates whether caches are disabled + cachingDisabled bool } // CoreConfig is used to parameterize a core @@ -318,7 +318,7 @@ func NewCore(conf *CoreConfig) (*Core, error) { logger: conf.Logger, defaultLeaseTTL: conf.DefaultLeaseTTL, maxLeaseTTL: conf.MaxLeaseTTL, - cacheDisabled: conf.DisableCache, + cachingDisabled: conf.DisableCache, } // Setup the backends diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index b4ef6a77ae..8dc806de62 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -70,7 +70,7 @@ func (d dynamicSystemView) Tainted() bool { return d.mountEntry.Tainted } -// CacheDisabled indicates whether to use caching behavior -func (d dynamicSystemView) CacheDisabled() bool { - return d.core.cacheDisabled +// CachingDisabled indicates whether to use caching behavior +func (d dynamicSystemView) CachingDisabled() bool { + return d.core.cachingDisabled } diff --git a/vault/policy_store.go b/vault/policy_store.go index d4729f740b..862002a426 100644 --- a/vault/policy_store.go +++ b/vault/policy_store.go @@ -40,7 +40,7 @@ func NewPolicyStore(view *BarrierView, system logical.SystemView) *PolicyStore { view: view, system: system, } - if !system.CacheDisabled() { + if !system.CachingDisabled() { cache, _ := lru.New2Q(policyCacheSize) p.lru = cache } @@ -100,7 +100,7 @@ func (ps *PolicyStore) SetPolicy(p *Policy) error { return fmt.Errorf("failed to persist policy: %v", err) } - if !ps.system.CacheDisabled() { + if !ps.system.CachingDisabled() { // Update the LRU cache ps.lru.Add(p.Name, p) } @@ -110,7 +110,7 @@ func (ps *PolicyStore) SetPolicy(p *Policy) error { // GetPolicy is used to fetch the named policy func (ps *PolicyStore) GetPolicy(name string) (*Policy, error) { defer metrics.MeasureSince([]string{"policy", "get_policy"}, time.Now()) - if !ps.system.CacheDisabled() { + if !ps.system.CachingDisabled() { // Check for cached policy if raw, ok := ps.lru.Get(name); ok { return raw.(*Policy), nil @@ -120,7 +120,7 @@ func (ps *PolicyStore) GetPolicy(name string) (*Policy, error) { // Special case the root policy if name == "root" { p := &Policy{Name: "root"} - if !ps.system.CacheDisabled() { + if !ps.system.CachingDisabled() { ps.lru.Add(p.Name, p) } return p, nil @@ -163,7 +163,7 @@ func (ps *PolicyStore) GetPolicy(name string) (*Policy, error) { policy = p } - if !ps.system.CacheDisabled() { + if !ps.system.CachingDisabled() { // Update the LRU cache ps.lru.Add(name, policy) } @@ -192,7 +192,7 @@ func (ps *PolicyStore) DeletePolicy(name string) error { return fmt.Errorf("failed to delete policy: %v", err) } - if !ps.system.CacheDisabled() { + if !ps.system.CachingDisabled() { // Clear the cache ps.lru.Remove(name) } diff --git a/vault/policy_store_test.go b/vault/policy_store_test.go index 456bc8375b..05cbd1c79e 100644 --- a/vault/policy_store_test.go +++ b/vault/policy_store_test.go @@ -16,7 +16,7 @@ func mockPolicyStore(t *testing.T) *PolicyStore { func mockPolicyStoreNoCache(t *testing.T) *PolicyStore { sysView := logical.TestSystemView() - sysView.CacheDisabledVal = true + sysView.CachingDisabledVal = true _, barrier, _ := mockBarrier(t) view := NewBarrierView(barrier, "foo/") p := NewPolicyStore(view, sysView) From 634cea72d7590cf2ca555ed8fb9df7b460cd1564 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 21 Apr 2016 21:08:44 +0000 Subject: [PATCH 03/11] Fix up commenting and some minor tidbits --- builtin/logical/transit/backend_test.go | 24 +++--------------------- builtin/logical/transit/caching_crud.go | 11 ++++++----- builtin/logical/transit/path_config.go | 8 ++++++-- builtin/logical/transit/path_keys.go | 14 ++++++++++++++ builtin/logical/transit/path_rotate.go | 2 +- builtin/logical/transit/policy_crud.go | 11 +++++++---- builtin/logical/transit/simple_crud.go | 10 ++++++---- 7 files changed, 43 insertions(+), 37 deletions(-) diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index 8e385f31b3..3b8319913f 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -602,24 +602,18 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) { } fd := &framework.FieldData{} - var retest bool var chosenFunc, chosenKey string //t.Logf("Starting") for { // Stop after 10 seconds if time.Now().Sub(startTime) > 10*time.Second { - if retest { - t.Errorf("ended runtime on a retest, id is %d", id) - } return } // Pick a function and a key - if !retest { - chosenFunc = funcs[rand.Int()%len(funcs)] - chosenKey = keys[rand.Int()%len(keys)] - } + chosenFunc = funcs[rand.Int()%len(funcs)] + chosenKey = keys[rand.Int()%len(keys)] fd.Raw = map[string]interface{}{ "name": chosenKey, @@ -630,7 +624,6 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) { _, err := be.pathPolicyWrite(req, fd) if err != nil { t.Fatalf("got an error: %v", err) - return } switch chosenFunc { @@ -642,7 +635,6 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) { resp, err := be.pathEncryptWrite(req, fd) if err != nil { t.Fatalf("got an error: %v, resp is %#v", err, *resp) - return } latestEncryptedText[chosenKey] = resp.Data["ciphertext"].(string) @@ -653,7 +645,6 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) { resp, err := be.pathRotateWrite(req, fd) if err != nil { t.Fatalf("got an error: %v, resp is %#v, chosenKey is %s", err, *resp, chosenKey) - return } // Decrypt the ciphertext and compare the result @@ -672,9 +663,7 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) { if resp.Data["error"].(string) == ErrTooOld { continue } - t.Errorf("got an error: %v, resp is %#v, ciphertext was %s, chosenKey is %s, id is %d", err, *resp, ct, chosenKey, id) - retest = true - continue + t.Fatalf("got an error: %v, resp is %#v, ciphertext was %s, chosenKey is %s, id is %d", err, *resp, ct, chosenKey, id) } ptb64 := resp.Data["plaintext"].(string) pt, err := base64.StdEncoding.DecodeString(ptb64) @@ -692,7 +681,6 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) { resp, err := be.pathPolicyRead(req, fd) if err != nil { t.Fatalf("got an error reading policy %s: %v", chosenKey, err) - return } latestVersion := resp.Data["latest_version"].(int) @@ -703,14 +691,8 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) { resp, err = be.pathConfigWrite(req, fd) if err != nil { t.Fatalf("got an error setting min decryption version: %v", err) - return } } - - if retest { - t.Errorf("success, setting retest false, id is %d", id) - } - retest = false } } diff --git a/builtin/logical/transit/caching_crud.go b/builtin/logical/transit/caching_crud.go index b0f3a9ccae..8759a4228c 100644 --- a/builtin/logical/transit/caching_crud.go +++ b/builtin/logical/transit/caching_crud.go @@ -6,8 +6,8 @@ import ( "github.com/hashicorp/vault/logical" ) -// policyCache implements CRUD operations with a simple locking cache of -// policies +// cachingPolicyCRUD implements CRUD operations with a simple locking cache of +// policies in memory type cachingPolicyCRUD struct { sync.RWMutex cache map[string]lockingPolicy @@ -19,6 +19,7 @@ func newCachingPolicyCRUD() *cachingPolicyCRUD { } } +// See general comments on the interface method func (p *cachingPolicyCRUD) getPolicy(storage logical.Storage, name string) (lockingPolicy, error) { // We don't defer this since we may need to give it up and get a write lock p.RLock() @@ -44,6 +45,7 @@ func (p *cachingPolicyCRUD) getPolicy(storage logical.Storage, name string) (loc return p.refreshPolicy(storage, name) } +// See general comments on the interface method func (p *cachingPolicyCRUD) refreshPolicy(storage logical.Storage, name string) (lockingPolicy, error) { // Check once more to ensure it hasn't been added to the cache since the lock was acquired if p.cache[name] != nil { @@ -70,8 +72,7 @@ func (p *cachingPolicyCRUD) refreshPolicy(storage logical.Storage, name string) return lp, nil } -// generatePolicy is used to create a new named policy with a randomly -// generated key. The caller should hold the write lock prior to calling this. +// See general comments on the interface method func (p *cachingPolicyCRUD) generatePolicy(storage logical.Storage, name string, derived bool) (lockingPolicy, error) { policy, err := generatePolicyCommon(p, storage, name, derived) if err != nil { @@ -96,7 +97,7 @@ func (p *cachingPolicyCRUD) generatePolicy(storage logical.Storage, name string, return lp, nil } -// deletePolicy deletes a policy +// See general comments on the interface method func (p *cachingPolicyCRUD) deletePolicy(storage logical.Storage, lp lockingPolicy, name string) error { err := deletePolicyCommon(p, lp, storage, name) if err != nil { diff --git a/builtin/logical/transit/path_config.go b/builtin/logical/transit/path_config.go index cf4348bd57..7395fbe106 100644 --- a/builtin/logical/transit/path_config.go +++ b/builtin/logical/transit/path_config.go @@ -52,10 +52,14 @@ func (b *backend) pathConfigWrite( logical.ErrInvalidRequest } + // Store some values so we can detect if the policy changed after locking + lp.RLock() currDeletionAllowed := lp.Policy().DeletionAllowed currMinDecryptionVersion := lp.Policy().MinDecryptionVersion + lp.RUnlock() - // Hold both locks since we want to ensure the policy doesn't change from underneath us + // Hold both locks since we want to ensure the policy doesn't change from + // underneath us b.policies.Lock() defer b.policies.Unlock() lp.Lock() @@ -77,7 +81,7 @@ func (b *backend) pathConfigWrite( resp := &logical.Response{} - // Check for anything to have been updated since we got the policy + // Check for anything to have been updated since we got the write lock if currDeletionAllowed != lp.Policy().DeletionAllowed || currMinDecryptionVersion != lp.Policy().MinDecryptionVersion { resp.AddWarning("key configuration has changed since this endpoint was called, not updating") diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index 02d7f882b3..32cbad4857 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -36,6 +36,7 @@ func (b *backend) pathKeys() *framework.Path { func (b *backend) pathPolicyWrite( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + // Grab a write lock right off the bat b.policies.Lock() defer b.policies.Unlock() @@ -95,6 +96,7 @@ func (b *backend) pathPolicyDelete( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) + // Some sanity checking lp, err := b.policies.getPolicy(req.Storage, name) if err != nil { return logical.ErrorResponse(fmt.Sprintf("error looking up policy %s, error is %s", name, err)), err @@ -103,11 +105,23 @@ func (b *backend) pathPolicyDelete( return logical.ErrorResponse(fmt.Sprintf("no such key %s", name)), logical.ErrInvalidRequest } + // Hold both locks since we'll be affecting both the cache (if it exists) + // and the locking policy itself b.policies.Lock() defer b.policies.Unlock() lp.Lock() defer lp.Unlock() + // Make sure that we have up-to-date values since deletePolicy will check + // things like whether deletion is allowed + lp, err = b.policies.refreshPolicy(req.Storage, name) + if err != nil { + return nil, err + } + if lp == nil { + return nil, fmt.Errorf("error finding key %s after locking for deletion", name) + } + err = b.policies.deletePolicy(req.Storage, lp, name) if err != nil { return logical.ErrorResponse(fmt.Sprintf("error deleting policy %s: %s", name, err)), err diff --git a/builtin/logical/transit/path_rotate.go b/builtin/logical/transit/path_rotate.go index dbde58e289..93e1951a91 100644 --- a/builtin/logical/transit/path_rotate.go +++ b/builtin/logical/transit/path_rotate.go @@ -41,6 +41,7 @@ func (b *backend) pathRotateWrite( return logical.ErrorResponse("key not found"), logical.ErrInvalidRequest } + // Store so we can detect later if this has changed out from under us keyVersion := lp.Policy().LatestVersion // lock the policies object so we can refresh @@ -70,7 +71,6 @@ func (b *backend) pathRotateWrite( return resp, nil } - //fmt.Printf("Rotating key %s, orig seen version is %d, currVersion is %d\n", name, keyVersion, lp.Policy().LatestVersion) // Rotate the policy err = lp.Policy().rotate(req.Storage) diff --git a/builtin/logical/transit/policy_crud.go b/builtin/logical/transit/policy_crud.go index 90c68c90fb..0a6b817b44 100644 --- a/builtin/logical/transit/policy_crud.go +++ b/builtin/logical/transit/policy_crud.go @@ -41,6 +41,9 @@ type policyCRUD interface { RUnlock() } +// The mutex is kept separate from the struct since we may set it to its own +// mutex (if the object is shared) or a shared mutext (if the object isn't +// shared and only the locking is) type mutexLockingPolicy struct { mutex *sync.RWMutex policy *Policy @@ -70,7 +73,8 @@ func (m *mutexLockingPolicy) SetPolicy(p *Policy) { m.policy = p } -// The caller should hold the write lock when calling this +// fetchPolicyFromStorage fetches the policy from backend storage. The caller +// should hold the write lock when calling this, to handle upgrades. func fetchPolicyFromStorage(storage logical.Storage, name string) (*Policy, error) { // Check if the policy already exists raw, err := storage.Get("policy/" + name) @@ -137,8 +141,6 @@ func generatePolicyCommon(p policyCRUD, storage logical.Storage, name string, de return policy, nil } - //log.Printf("generating a new policy with name %s", name) - // Create the policy object policy = &Policy{ Name: name, @@ -157,7 +159,8 @@ func generatePolicyCommon(p policyCRUD, storage logical.Storage, name string, de return policy, err } -// deletePolicy deletes a policy. The caller should hold the write lock for both the policy and lockingPolicy prior to calling this. +// deletePolicyCommon deletes a policy. The caller should hold the write lock +// for both the policy and lockingPolicy prior to calling this. func deletePolicyCommon(p policyCRUD, lp lockingPolicy, storage logical.Storage, name string) error { if lp.Policy() == nil { // This got deleted before we grabbed the lock diff --git a/builtin/logical/transit/simple_crud.go b/builtin/logical/transit/simple_crud.go index 874f760d78..8680a0cbea 100644 --- a/builtin/logical/transit/simple_crud.go +++ b/builtin/logical/transit/simple_crud.go @@ -7,8 +7,8 @@ import ( ) // Directly implements CRUD operations without caching, mapped to the backend, -// but implements locking to ensure that we can't overwrite data on the backend -// from multiple operators +// but implements shared locking to ensure that we can't overwrite data on the +// backend from multiple operators type simplePolicyCRUD struct { sync.RWMutex locks map[string]*sync.RWMutex @@ -38,6 +38,7 @@ func (p *simplePolicyCRUD) ensureLockExists(name string) { p.locksMapMutex.RUnlock() } +// See general comments on the interface method func (p *simplePolicyCRUD) getPolicy(storage logical.Storage, name string) (lockingPolicy, error) { // Use a write lock since fetching the policy can cause a need for upgrade persistence p.Lock() @@ -46,6 +47,7 @@ func (p *simplePolicyCRUD) getPolicy(storage logical.Storage, name string) (lock return p.refreshPolicy(storage, name) } +// See general comments on the interface method func (p *simplePolicyCRUD) refreshPolicy(storage logical.Storage, name string) (lockingPolicy, error) { p.ensureLockExists(name) @@ -65,7 +67,7 @@ func (p *simplePolicyCRUD) refreshPolicy(storage logical.Storage, name string) ( return lp, nil } -// The caller must hold the write lock when calling this +// See general comments on the interface method func (p *simplePolicyCRUD) generatePolicy(storage logical.Storage, name string, derived bool) (lockingPolicy, error) { p.ensureLockExists(name) @@ -82,7 +84,7 @@ func (p *simplePolicyCRUD) generatePolicy(storage logical.Storage, name string, return lp, nil } -// The caller must hold the write lock when calling this +// See general comments on the interface method func (p *simplePolicyCRUD) deletePolicy(storage logical.Storage, lp lockingPolicy, name string) error { return deletePolicyCommon(p, lp, storage, name) } From ddec2ed86b6404d8a690527d42d0fd1febf5ff4f Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Fri, 22 Apr 2016 13:41:32 +0000 Subject: [PATCH 04/11] Slightly nicer check for LRU in policy store --- vault/policy_store.go | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/vault/policy_store.go b/vault/policy_store.go index 862002a426..90b6e23416 100644 --- a/vault/policy_store.go +++ b/vault/policy_store.go @@ -22,9 +22,8 @@ const ( // PolicyStore is used to provide durable storage of policy, and to // manage ACLs associated with them. type PolicyStore struct { - view *BarrierView - lru *lru.TwoQueueCache - system logical.SystemView + view *BarrierView + lru *lru.TwoQueueCache } // PolicyEntry is used to store a policy by name @@ -37,8 +36,7 @@ type PolicyEntry struct { // using a given view. It used used to durable store and manage named policy. func NewPolicyStore(view *BarrierView, system logical.SystemView) *PolicyStore { p := &PolicyStore{ - view: view, - system: system, + view: view, } if !system.CachingDisabled() { cache, _ := lru.New2Q(policyCacheSize) @@ -100,7 +98,7 @@ func (ps *PolicyStore) SetPolicy(p *Policy) error { return fmt.Errorf("failed to persist policy: %v", err) } - if !ps.system.CachingDisabled() { + if ps.lru != nil { // Update the LRU cache ps.lru.Add(p.Name, p) } @@ -110,7 +108,7 @@ func (ps *PolicyStore) SetPolicy(p *Policy) error { // GetPolicy is used to fetch the named policy func (ps *PolicyStore) GetPolicy(name string) (*Policy, error) { defer metrics.MeasureSince([]string{"policy", "get_policy"}, time.Now()) - if !ps.system.CachingDisabled() { + if ps.lru != nil { // Check for cached policy if raw, ok := ps.lru.Get(name); ok { return raw.(*Policy), nil @@ -120,7 +118,7 @@ func (ps *PolicyStore) GetPolicy(name string) (*Policy, error) { // Special case the root policy if name == "root" { p := &Policy{Name: "root"} - if !ps.system.CachingDisabled() { + if ps.lru != nil { ps.lru.Add(p.Name, p) } return p, nil @@ -163,7 +161,7 @@ func (ps *PolicyStore) GetPolicy(name string) (*Policy, error) { policy = p } - if !ps.system.CachingDisabled() { + if ps.lru != nil { // Update the LRU cache ps.lru.Add(name, policy) } @@ -192,7 +190,7 @@ func (ps *PolicyStore) DeletePolicy(name string) error { return fmt.Errorf("failed to delete policy: %v", err) } - if !ps.system.CachingDisabled() { + if ps.lru != nil { // Clear the cache ps.lru.Remove(name) } From 3ab71ca239fd1f67c1226660fb2c6091cb771714 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Fri, 22 Apr 2016 16:21:27 +0000 Subject: [PATCH 05/11] Address feedback --- builtin/logical/transit/policy_crud.go | 2 +- builtin/logical/transit/simple_crud.go | 20 ++++++-------------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/builtin/logical/transit/policy_crud.go b/builtin/logical/transit/policy_crud.go index 0a6b817b44..46472de962 100644 --- a/builtin/logical/transit/policy_crud.go +++ b/builtin/logical/transit/policy_crud.go @@ -42,7 +42,7 @@ type policyCRUD interface { } // The mutex is kept separate from the struct since we may set it to its own -// mutex (if the object is shared) or a shared mutext (if the object isn't +// mutex (if the object is shared) or a shared mutex (if the object isn't // shared and only the locking is) type mutexLockingPolicy struct { mutex *sync.RWMutex diff --git a/builtin/logical/transit/simple_crud.go b/builtin/logical/transit/simple_crud.go index 8680a0cbea..6538c7eeae 100644 --- a/builtin/logical/transit/simple_crud.go +++ b/builtin/logical/transit/simple_crud.go @@ -11,8 +11,7 @@ import ( // backend from multiple operators type simplePolicyCRUD struct { sync.RWMutex - locks map[string]*sync.RWMutex - locksMapMutex sync.RWMutex + locks map[string]*sync.RWMutex } func newSimplePolicyCRUD() *simplePolicyCRUD { @@ -21,21 +20,14 @@ func newSimplePolicyCRUD() *simplePolicyCRUD { } } +// The write lock must be held before calling this; for this CRUD type this +// should always be the case, since the only method not requiring a write lock +// when called is getPolicy, and that itself grabs a write lock before calling +// refreshPolicy func (p *simplePolicyCRUD) ensureLockExists(name string) { - p.locksMapMutex.RLock() - if p.locks[name] == nil { - p.locksMapMutex.RUnlock() - p.locksMapMutex.Lock() - // Make sure nothing has appeared since we switched the lock type - if p.locks[name] == nil { - p.locks[name] = &sync.RWMutex{} - } - p.locksMapMutex.Unlock() - return + p.locks[name] = &sync.RWMutex{} } - - p.locksMapMutex.RUnlock() } // See general comments on the interface method From c598a12ab98390fb59ec6023af756c4ae77c814d Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Tue, 26 Apr 2016 15:39:19 +0000 Subject: [PATCH 06/11] Switch to lockManager --- builtin/logical/transit/backend.go | 8 +- builtin/logical/transit/backend_test.go | 8 +- builtin/logical/transit/caching_crud.go | 110 -------- builtin/logical/transit/lock_manager.go | 337 ++++++++++++++++++++++++ builtin/logical/transit/path_config.go | 53 ++-- builtin/logical/transit/path_datakey.go | 17 +- builtin/logical/transit/path_decrypt.go | 16 +- builtin/logical/transit/path_encrypt.go | 59 ++--- builtin/logical/transit/path_keys.go | 79 +++--- builtin/logical/transit/path_rewrap.go | 17 +- builtin/logical/transit/path_rotate.go | 28 +- builtin/logical/transit/policy.go | 61 +++++ builtin/logical/transit/policy_crud.go | 188 ------------- builtin/logical/transit/policy_test.go | 133 +++++----- builtin/logical/transit/simple_crud.go | 82 ------ 15 files changed, 565 insertions(+), 631 deletions(-) delete mode 100644 builtin/logical/transit/caching_crud.go create mode 100644 builtin/logical/transit/lock_manager.go delete mode 100644 builtin/logical/transit/policy_crud.go delete mode 100644 builtin/logical/transit/simple_crud.go diff --git a/builtin/logical/transit/backend.go b/builtin/logical/transit/backend.go index bd50d8d2ed..94d07cbd92 100644 --- a/builtin/logical/transit/backend.go +++ b/builtin/logical/transit/backend.go @@ -33,16 +33,12 @@ func Backend(conf *logical.BackendConfig) *backend { Secrets: []*framework.Secret{}, } - if conf.System.CachingDisabled() { - b.policies = newSimplePolicyCRUD() - } else { - b.policies = newCachingPolicyCRUD() - } + b.lm = newLockManager(conf.System.CachingDisabled()) return &b } type backend struct { *framework.Backend - policies policyCRUD + lm *lockManager } diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index 3b8319913f..3b31cc2bbf 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -604,7 +604,7 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) { var chosenFunc, chosenKey string - //t.Logf("Starting") + //t.Errorf("Starting %d", id) for { // Stop after 10 seconds if time.Now().Sub(startTime) > 10*time.Second { @@ -629,7 +629,7 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) { switch chosenFunc { // Encrypt our plaintext and store the result case "encrypt": - //t.Logf("%s, %s", chosenFunc, chosenKey) + //t.Errorf("%s, %s, %d", chosenFunc, chosenKey, id) fd.Raw["plaintext"] = base64.StdEncoding.EncodeToString([]byte(testPlaintext)) fd.Schema = be.pathEncrypt().Fields resp, err := be.pathEncryptWrite(req, fd) @@ -649,7 +649,7 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) { // Decrypt the ciphertext and compare the result case "decrypt": - //t.Logf("%s, %s", chosenFunc, chosenKey) + //t.Errorf("%s, %s, %d", chosenFunc, chosenKey, id) ct := latestEncryptedText[chosenKey] if ct == "" { continue @@ -677,7 +677,7 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) { // Change the min version, which also tests the archive functionality case "change_min_version": - //t.Logf("%s, %s", chosenFunc, chosenKey) + //t.Errorf("%s, %s, %d", chosenFunc, chosenKey, id) resp, err := be.pathPolicyRead(req, fd) if err != nil { t.Fatalf("got an error reading policy %s: %v", chosenKey, err) diff --git a/builtin/logical/transit/caching_crud.go b/builtin/logical/transit/caching_crud.go deleted file mode 100644 index 8759a4228c..0000000000 --- a/builtin/logical/transit/caching_crud.go +++ /dev/null @@ -1,110 +0,0 @@ -package transit - -import ( - "sync" - - "github.com/hashicorp/vault/logical" -) - -// cachingPolicyCRUD implements CRUD operations with a simple locking cache of -// policies in memory -type cachingPolicyCRUD struct { - sync.RWMutex - cache map[string]lockingPolicy -} - -func newCachingPolicyCRUD() *cachingPolicyCRUD { - return &cachingPolicyCRUD{ - cache: map[string]lockingPolicy{}, - } -} - -// See general comments on the interface method -func (p *cachingPolicyCRUD) getPolicy(storage logical.Storage, name string) (lockingPolicy, error) { - // We don't defer this since we may need to give it up and get a write lock - p.RLock() - - // First, see if we're in the cache -- if so, return that - if p.cache[name] != nil { - defer p.RUnlock() - return p.cache[name], nil - } - - // If we didn't find anything, we'll need to write into the cache, plus possibly - // persist the entry, so lock the cache - p.RUnlock() - p.Lock() - defer p.Unlock() - - // Check one more time to ensure that another process did not write during - // our lock switcheroo. - if p.cache[name] != nil { - return p.cache[name], nil - } - - return p.refreshPolicy(storage, name) -} - -// See general comments on the interface method -func (p *cachingPolicyCRUD) refreshPolicy(storage logical.Storage, name string) (lockingPolicy, error) { - // Check once more to ensure it hasn't been added to the cache since the lock was acquired - if p.cache[name] != nil { - return p.cache[name], nil - } - - // Note that we don't need to create the locking entry until the end, - // because the policy wasn't in the cache so we don't know about it, and we - // hold the cache lock so nothing else can be writing it in right now - policy, err := fetchPolicyFromStorage(storage, name) - if err != nil { - return nil, err - } - if policy == nil { - return nil, nil - } - - lp := &mutexLockingPolicy{ - policy: policy, - mutex: &sync.RWMutex{}, - } - p.cache[name] = lp - - return lp, nil -} - -// See general comments on the interface method -func (p *cachingPolicyCRUD) generatePolicy(storage logical.Storage, name string, derived bool) (lockingPolicy, error) { - policy, err := generatePolicyCommon(p, storage, name, derived) - if err != nil { - return nil, err - } - - // Now we need to check again in the cache to ensure the policy wasn't - // created since we ran generatePolicy and then got the lock. A policy - // being created holds a write lock until it's done (starting from this - // point), so it'll be in the cache at this point. - if lp := p.cache[name]; lp != nil { - return lp, nil - } - - lp := &mutexLockingPolicy{ - policy: policy, - mutex: &sync.RWMutex{}, - } - p.cache[name] = lp - - // Return the policy - return lp, nil -} - -// See general comments on the interface method -func (p *cachingPolicyCRUD) deletePolicy(storage logical.Storage, lp lockingPolicy, name string) error { - err := deletePolicyCommon(p, lp, storage, name) - if err != nil { - return err - } - - delete(p.cache, name) - - return nil -} diff --git a/builtin/logical/transit/lock_manager.go b/builtin/logical/transit/lock_manager.go new file mode 100644 index 0000000000..6f994b4edb --- /dev/null +++ b/builtin/logical/transit/lock_manager.go @@ -0,0 +1,337 @@ +package transit + +import ( + "encoding/json" + "fmt" + "sync" + + "github.com/hashicorp/vault/logical" +) + +const ( + shared = false + exclusive = true +) + +type lockManager struct { + // A lock for each named key + locks map[string]*sync.RWMutex + + // A mutex for the map itself + lockMutex sync.RWMutex + + // If caching is enabled, the map of name to in-memory policy cache + cache map[string]*Policy + + // Used for global locking, and as the cache map mutex + globalMutex sync.RWMutex +} + +func newLockManager(cacheDisabled bool) *lockManager { + lm := &lockManager{ + locks: map[string]*sync.RWMutex{}, + } + if !cacheDisabled { + lm.cache = map[string]*Policy{} + } + return lm +} + +func (lm *lockManager) CacheActive() bool { + return lm.cache != nil +} + +func (lm *lockManager) LockAll(name string) { + lm.globalMutex.Lock() + lm.LockPolicy(name, exclusive) +} + +func (lm *lockManager) UnlockAll(name string) { + lm.UnlockPolicy(name, exclusive) + lm.globalMutex.Unlock() +} + +func (lm *lockManager) LockPolicy(name string, writeLock bool) { + lm.lockMutex.RLock() + lock := lm.locks[name] + if lock != nil { + // We want to give this up before locking the lock, but it's safe -- + // the only time we ever write to a value in this map is the first time + // we access the value, so it won't be changing out from under us + lm.lockMutex.RUnlock() + if writeLock { + lock.Lock() + } else { + lock.RLock() + } + return + } + + lm.lockMutex.RUnlock() + lm.lockMutex.Lock() + + // Don't defer the unlock call because if we get a valid lock below we want + // to release the lock mutex right away to avoid the possibility of + // deadlock by trying to grab the second lock + + // Check to make sure it hasn't been created since + lock = lm.locks[name] + if lock != nil { + lm.lockMutex.Unlock() + if writeLock { + lock.Lock() + } else { + lock.RLock() + } + return + } + + lock = &sync.RWMutex{} + lm.locks[name] = lock + lm.lockMutex.Unlock() + if writeLock { + lock.Lock() + } else { + lock.RLock() + } +} + +func (lm *lockManager) UnlockPolicy(name string, writeLock bool) { + lm.lockMutex.RLock() + lock := lm.locks[name] + lm.lockMutex.RUnlock() + + if writeLock { + lock.Unlock() + } else { + lock.RUnlock() + } +} + +func (lm *lockManager) GetPolicy(storage logical.Storage, name string) (*Policy, bool, error) { + p, lt, _, err := lm.getPolicyCommon(storage, name, false, false) + return p, lt, err +} + +func (lm *lockManager) GetPolicyUpsert(storage logical.Storage, name string, derived bool) (*Policy, bool, bool, error) { + return lm.getPolicyCommon(storage, name, true, derived) +} + +// When the function returns, a lock will be held on the policy if err == nil. +// The type of lock will be indicated by the return value. It is the caller's +// responsibility to unlock. +func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, upsert, derived bool) (p *Policy, lockType bool, upserted bool, err error) { + // If we are using a cache, lock it now to avoid having to do really + // complicated lock juggling as we call various functions. We'll also defer + // the store into the cache. + lockType = shared + lm.LockPolicy(name, shared) + + if lm.CacheActive() { + lm.globalMutex.RLock() + p = lm.cache[name] + if p != nil { + defer lm.globalMutex.RUnlock() + return + } + lm.globalMutex.RUnlock() + + // When we return, since we didn't have the policy in the cache, if + // there was no error, write the value in. + defer func() { + if err == nil { + lm.globalMutex.Lock() + defer lm.globalMutex.Unlock() + // Make sure a policy didn't appear + exp := lm.cache[name] + if exp != nil { + p = exp + return + } + + lm.cache[name] = p + } + }() + } + + p, err = lm.getStoredPolicy(storage, name) + if err != nil { + defer lm.UnlockPolicy(name, shared) + return + } + + if p == nil { + if !upsert { + defer lm.UnlockPolicy(name, shared) + return + } + + // Get an exlusive lock; on success, check again to ensure that no + // policy exists. Note that if we are using a cache we will already be + // serializing this entire code path and it's currently the only one + // that generates policies, so we don't need to check the cache here; + // simply checking the disk again is sufficient. + lm.UnlockPolicy(name, shared) + lockType = exclusive + lm.LockPolicy(name, exclusive) + + p, err = lm.getStoredPolicy(storage, name) + if err != nil { + defer lm.UnlockPolicy(name, exclusive) + return + } + if p != nil { + return + } + + upserted = true + + p = &Policy{ + Name: name, + CipherMode: "aes-gcm", + Derived: derived, + } + if derived { + p.KDFMode = kdfMode + } + + err = p.rotate(storage) + if err != nil { + defer lm.UnlockPolicy(name, exclusive) + p = nil + } + + // We don't need to worry about upgrading since it will be a new policy + return + } + + if p.needsUpgrade() { + lm.UnlockPolicy(name, shared) + lockType = exclusive + lm.LockPolicy(name, exclusive) + + // Reload the policy with the write lock to ensure we still need the upgrade + p, err = lm.getStoredPolicy(storage, name) + if err != nil { + defer lm.UnlockPolicy(name, exclusive) + return + } + if p == nil { + defer lm.UnlockPolicy(name, exclusive) + err = fmt.Errorf("error reloading policy for upgrade") + return + } + + if !p.needsUpgrade() { + // Already happened, return the newly loaded policy + return + } + + err = p.upgrade(storage) + if err != nil { + defer lm.UnlockPolicy(name, exclusive) + } + } + + return +} + +func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error { + lm.LockAll(name) + defer lm.UnlockAll(name) + + var p *Policy + var err error + + if lm.CacheActive() { + p = lm.cache[name] + if p == nil { + return fmt.Errorf("could not delete policy; not found") + } + } else { + p, err = lm.getStoredPolicy(storage, name) + if err != nil { + return err + } + if p == nil { + return fmt.Errorf("could not delete policy; not found") + } + } + + if !p.DeletionAllowed { + return fmt.Errorf("deletion is not allowed for this policy") + } + + err = storage.Delete("policy/" + name) + if err != nil { + return fmt.Errorf("error deleting policy %s: %s", name, err) + } + + err = storage.Delete("archive/" + name) + if err != nil { + return fmt.Errorf("error deleting archive %s: %s", name, err) + } + + if lm.CacheActive() { + delete(lm.cache, name) + } + + return nil +} + +// When this function returns it's the responsibility of the caller to call UnlockAll if err is not nil +func (lm *lockManager) RefreshPolicy(storage logical.Storage, name string) (p *Policy, err error) { + lm.LockPolicy(name, exclusive) + + if lm.CacheActive() { + p = lm.cache[name] + if p != nil { + return + } + err = fmt.Errorf("could not refresh policy; not found") + defer lm.UnlockPolicy(name, exclusive) + return + } + + p, err = lm.getStoredPolicy(storage, name) + if err != nil { + defer lm.UnlockPolicy(name, exclusive) + return + } + + if p == nil { + err = fmt.Errorf("could not refresh policy; not found") + defer lm.UnlockPolicy(name, exclusive) + } + + if p.needsUpgrade() { + err = p.upgrade(storage) + if err != nil { + defer lm.UnlockPolicy(name, exclusive) + } + } + + return +} + +func (lm *lockManager) getStoredPolicy(storage logical.Storage, name string) (*Policy, error) { + // Check if the policy already exists + raw, err := storage.Get("policy/" + name) + if err != nil { + return nil, err + } + if raw == nil { + return nil, nil + } + + // Decode the policy + policy := &Policy{ + Keys: KeyEntryMap{}, + } + err = json.Unmarshal(raw.Value, policy) + if err != nil { + return nil, err + } + + return policy, nil +} diff --git a/builtin/logical/transit/path_config.go b/builtin/logical/transit/path_config.go index 7395fbe106..eabb3603f6 100644 --- a/builtin/logical/transit/path_config.go +++ b/builtin/logical/transit/path_config.go @@ -41,49 +41,38 @@ func (b *backend) pathConfigWrite( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - // Check if the policy already exists - lp, err := b.policies.getPolicy(req.Storage, name) + // Check if the policy already exists before we lock everything + p, lockType, err := b.lm.GetPolicy(req.Storage, name) if err != nil { return nil, err } - if lp == nil { + if p == nil { return logical.ErrorResponse( fmt.Sprintf("no existing key named %s could be found", name)), logical.ErrInvalidRequest } - // Store some values so we can detect if the policy changed after locking - lp.RLock() - currDeletionAllowed := lp.Policy().DeletionAllowed - currMinDecryptionVersion := lp.Policy().MinDecryptionVersion - lp.RUnlock() + // Store some values so we can detect a change when we lock everything + currDeletionAllowed := p.DeletionAllowed + currMinDecryptionVersion := p.MinDecryptionVersion - // Hold both locks since we want to ensure the policy doesn't change from - // underneath us - b.policies.Lock() - defer b.policies.Unlock() - lp.Lock() - defer lp.Unlock() + b.lm.UnlockPolicy(name, lockType) // Refresh in case it's changed since before we grabbed the lock - lp, err = b.policies.refreshPolicy(req.Storage, name) + p, err = b.lm.RefreshPolicy(req.Storage, name) if err != nil { return nil, err } - if lp == nil { + if p == nil { return nil, fmt.Errorf("error finding key %s after locking for changes", name) } - - // Verify if wasn't deleted before we grabbed the lock - if lp.Policy() == nil { - return nil, fmt.Errorf("no existing key named %s could be found", name) - } + defer b.lm.UnlockPolicy(name, exclusive) resp := &logical.Response{} // Check for anything to have been updated since we got the write lock - if currDeletionAllowed != lp.Policy().DeletionAllowed || - currMinDecryptionVersion != lp.Policy().MinDecryptionVersion { + if currDeletionAllowed != p.DeletionAllowed || + currMinDecryptionVersion != p.MinDecryptionVersion { resp.AddWarning("key configuration has changed since this endpoint was called, not updating") return resp, nil } @@ -104,12 +93,12 @@ func (b *backend) pathConfigWrite( } if minDecryptionVersion > 0 && - minDecryptionVersion != lp.Policy().MinDecryptionVersion { - if minDecryptionVersion > lp.Policy().LatestVersion { + minDecryptionVersion != p.MinDecryptionVersion { + if minDecryptionVersion > p.LatestVersion { return logical.ErrorResponse( - fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, lp.Policy().LatestVersion)), nil + fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, p.LatestVersion)), nil } - lp.Policy().MinDecryptionVersion = minDecryptionVersion + p.MinDecryptionVersion = minDecryptionVersion persistNeeded = true } } @@ -117,8 +106,8 @@ func (b *backend) pathConfigWrite( allowDeletionInt, ok := d.GetOk("deletion_allowed") if ok { allowDeletion := allowDeletionInt.(bool) - if allowDeletion != lp.Policy().DeletionAllowed { - lp.Policy().DeletionAllowed = allowDeletion + if allowDeletion != p.DeletionAllowed { + p.DeletionAllowed = allowDeletion persistNeeded = true } } @@ -126,8 +115,8 @@ func (b *backend) pathConfigWrite( // Add this as a guard here before persisting since we now require the min // decryption version to start at 1; even if it's not explicitly set here, // force the upgrade - if lp.Policy().MinDecryptionVersion == 0 { - lp.Policy().MinDecryptionVersion = 1 + if p.MinDecryptionVersion == 0 { + p.MinDecryptionVersion = 1 persistNeeded = true } @@ -135,7 +124,7 @@ func (b *backend) pathConfigWrite( return nil, nil } - return resp, lp.Policy().Persist(req.Storage) + return resp, p.Persist(req.Storage) } const pathConfigHelpSyn = `Configure a named encryption key` diff --git a/builtin/logical/transit/path_datakey.go b/builtin/logical/transit/path_datakey.go index 9e64836730..bd6fca6517 100644 --- a/builtin/logical/transit/path_datakey.go +++ b/builtin/logical/transit/path_datakey.go @@ -73,23 +73,14 @@ func (b *backend) pathDatakeyWrite( } // Get the policy - lp, err := b.policies.getPolicy(req.Storage, name) + p, lockType, err := b.lm.GetPolicy(req.Storage, name) if err != nil { return nil, err } - - // Error if invalid policy - if lp == nil { + if p == nil { return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } - - lp.RLock() - defer lp.RUnlock() - - // Verify if wasn't deleted before we grabbed the lock - if lp.Policy() == nil { - return nil, fmt.Errorf("no existing policy named %s could be found", name) - } + defer b.lm.UnlockPolicy(name, lockType) newKey := make([]byte, 32) bits := d.Get("bits").(int) @@ -107,7 +98,7 @@ func (b *backend) pathDatakeyWrite( return nil, err } - ciphertext, err := lp.Policy().Encrypt(context, base64.StdEncoding.EncodeToString(newKey)) + ciphertext, err := p.Encrypt(context, base64.StdEncoding.EncodeToString(newKey)) if err != nil { switch err.(type) { case certutil.UserError: diff --git a/builtin/logical/transit/path_decrypt.go b/builtin/logical/transit/path_decrypt.go index a58709858a..8c5799331f 100644 --- a/builtin/logical/transit/path_decrypt.go +++ b/builtin/logical/transit/path_decrypt.go @@ -58,25 +58,17 @@ func (b *backend) pathDecryptWrite( } // Get the policy - lp, err := b.policies.getPolicy(req.Storage, name) + p, lockType, err := b.lm.GetPolicy(req.Storage, name) if err != nil { return nil, err } - - // Error if invalid policy - if lp == nil { + if p == nil { return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } - lp.RLock() - defer lp.RUnlock() + defer b.lm.UnlockPolicy(name, lockType) - // Verify if wasn't deleted before we grabbed the lock - if lp.Policy() == nil { - return nil, fmt.Errorf("no existing policy named %s could be found", name) - } - - plaintext, err := lp.Policy().Decrypt(context, ciphertext) + plaintext, err := p.Decrypt(context, ciphertext) if err != nil { switch err.(type) { case certutil.UserError: diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go index 8c4cc91e7b..bba7667972 100644 --- a/builtin/logical/transit/path_encrypt.go +++ b/builtin/logical/transit/path_encrypt.go @@ -44,12 +44,14 @@ func (b *backend) pathEncrypt() *framework.Path { func (b *backend) pathEncryptExistenceCheck( req *logical.Request, d *framework.FieldData) (bool, error) { name := d.Get("name").(string) - lp, err := b.policies.getPolicy(req.Storage, name) + p, lockType, err := b.lm.GetPolicy(req.Storage, name) if err != nil { return false, err } - - return lp != nil, nil + if p != nil { + defer b.lm.UnlockPolicy(name, lockType) + } + return p != nil, nil } func (b *backend) pathEncryptWrite( @@ -63,8 +65,8 @@ func (b *backend) pathEncryptWrite( // Decode the context if any contextRaw := d.Get("context").(string) var context []byte + var err error if len(contextRaw) != 0 { - var err error context, err = base64.StdEncoding.DecodeString(contextRaw) if err != nil { return logical.ErrorResponse("failed to decode context as base64"), logical.ErrInvalidRequest @@ -72,43 +74,23 @@ func (b *backend) pathEncryptWrite( } // Get the policy - lp, err := b.policies.getPolicy(req.Storage, name) + var p *Policy + var lockType bool + var upserted bool + if req.Operation == logical.CreateOperation { + p, lockType, upserted, err = b.lm.GetPolicyUpsert(req.Storage, name, len(context) != 0) + } else { + p, lockType, err = b.lm.GetPolicy(req.Storage, name) + } if err != nil { return nil, err } - - // Error or upsert if invalid policy - if lp == nil { - if req.Operation != logical.CreateOperation { - return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest - } - - // Get a write lock - b.policies.Lock() - - isDerived := len(context) != 0 - - // This also checks to make sure one hasn't been created since we grabbed the write lock - lp, err = b.policies.generatePolicy(req.Storage, name, isDerived) - // If the error is that the policy has been created in the interim we - // will get the policy back, so only consider it an error if err is not - // nil and we do not get a policy back - if err != nil && lp != nil { - b.policies.Unlock() - return nil, err - } - b.policies.Unlock() + if p == nil { + return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } + defer b.lm.UnlockPolicy(name, lockType) - lp.RLock() - defer lp.RUnlock() - - // Verify if wasn't deleted before we grabbed the lock - if lp.Policy() == nil { - return nil, fmt.Errorf("no existing policy named %s could be found", name) - } - - ciphertext, err := lp.Policy().Encrypt(context, value) + ciphertext, err := p.Encrypt(context, value) if err != nil { switch err.(type) { case certutil.UserError: @@ -130,6 +112,11 @@ func (b *backend) pathEncryptWrite( "ciphertext": ciphertext, }, } + + if req.Operation == logical.CreateOperation && !upserted { + resp.AddWarning("Attempted creation of the key during the encrypt operation, but it was created beforehand") + } + return resp, nil } diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index 32cbad4857..956b295d15 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -36,55 +36,58 @@ func (b *backend) pathKeys() *framework.Path { func (b *backend) pathPolicyWrite( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - // Grab a write lock right off the bat - b.policies.Lock() - defer b.policies.Unlock() - name := d.Get("name").(string) derived := d.Get("derived").(bool) - // Generate the policy; this will also check if it exists for safety - _, err := b.policies.generatePolicy(req.Storage, name, derived) - return nil, err + p, lockType, upserted, err := b.lm.GetPolicyUpsert(req.Storage, name, derived) + if err != nil { + return nil, err + } + if p == nil { + return nil, fmt.Errorf("error generating key: returned policy was nil") + } + + defer b.lm.UnlockPolicy(name, lockType) + + resp := &logical.Response{} + if !upserted { + resp.AddWarning(fmt.Sprintf("key %s already existed", name)) + } + + return nil, nil } func (b *backend) pathPolicyRead( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - lp, err := b.policies.getPolicy(req.Storage, name) + p, lockType, err := b.lm.GetPolicy(req.Storage, name) if err != nil { return nil, err } - if lp == nil { + if p == nil { return nil, nil } - lp.RLock() - defer lp.RUnlock() - - // Verify if wasn't deleted before we grabbed the lock - if lp.Policy() == nil { - return nil, fmt.Errorf("no existing policy named %s could be found", name) - } + defer b.lm.UnlockPolicy(name, lockType) // Return the response resp := &logical.Response{ Data: map[string]interface{}{ - "name": lp.Policy().Name, - "cipher_mode": lp.Policy().CipherMode, - "derived": lp.Policy().Derived, - "deletion_allowed": lp.Policy().DeletionAllowed, - "min_decryption_version": lp.Policy().MinDecryptionVersion, - "latest_version": lp.Policy().LatestVersion, + "name": p.Name, + "cipher_mode": p.CipherMode, + "derived": p.Derived, + "deletion_allowed": p.DeletionAllowed, + "min_decryption_version": p.MinDecryptionVersion, + "latest_version": p.LatestVersion, }, } - if lp.Policy().Derived { - resp.Data["kdf_mode"] = lp.Policy().KDFMode + if p.Derived { + resp.Data["kdf_mode"] = p.KDFMode } retKeys := map[string]int64{} - for k, v := range lp.Policy().Keys { + for k, v := range p.Keys { retKeys[strconv.Itoa(k)] = v.CreationTime } resp.Data["keys"] = retKeys @@ -96,33 +99,17 @@ func (b *backend) pathPolicyDelete( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - // Some sanity checking - lp, err := b.policies.getPolicy(req.Storage, name) + // Some sanity checking before we lock it all in the DeletePolicy method + p, lockType, err := b.lm.GetPolicy(req.Storage, name) if err != nil { return logical.ErrorResponse(fmt.Sprintf("error looking up policy %s, error is %s", name, err)), err } - if lp == nil { + if p == nil { return logical.ErrorResponse(fmt.Sprintf("no such key %s", name)), logical.ErrInvalidRequest } + b.lm.UnlockPolicy(name, lockType) - // Hold both locks since we'll be affecting both the cache (if it exists) - // and the locking policy itself - b.policies.Lock() - defer b.policies.Unlock() - lp.Lock() - defer lp.Unlock() - - // Make sure that we have up-to-date values since deletePolicy will check - // things like whether deletion is allowed - lp, err = b.policies.refreshPolicy(req.Storage, name) - if err != nil { - return nil, err - } - if lp == nil { - return nil, fmt.Errorf("error finding key %s after locking for deletion", name) - } - - err = b.policies.deletePolicy(req.Storage, lp, name) + err = b.lm.DeletePolicy(req.Storage, name) if err != nil { return logical.ErrorResponse(fmt.Sprintf("error deleting policy %s: %s", name, err)), err } diff --git a/builtin/logical/transit/path_rewrap.go b/builtin/logical/transit/path_rewrap.go index 2233689b3b..61e3f607e8 100644 --- a/builtin/logical/transit/path_rewrap.go +++ b/builtin/logical/transit/path_rewrap.go @@ -59,25 +59,18 @@ func (b *backend) pathRewrapWrite( } // Get the policy - lp, err := b.policies.getPolicy(req.Storage, name) + p, lockType, err := b.lm.GetPolicy(req.Storage, name) if err != nil { return nil, err } - // Error if invalid policy - if lp == nil { + if p == nil { return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } - lp.RLock() - defer lp.RUnlock() + defer b.lm.UnlockPolicy(name, lockType) - // Verify if wasn't deleted before we grabbed the lock - if lp.Policy() == nil { - return nil, fmt.Errorf("no existing policy named %s could be found", name) - } - - plaintext, err := lp.Policy().Decrypt(context, value) + plaintext, err := p.Decrypt(context, value) if err != nil { switch err.(type) { case certutil.UserError: @@ -93,7 +86,7 @@ func (b *backend) pathRewrapWrite( return nil, fmt.Errorf("empty plaintext returned during rewrap") } - ciphertext, err := lp.Policy().Encrypt(context, plaintext) + ciphertext, err := p.Encrypt(context, plaintext) if err != nil { switch err.(type) { case certutil.UserError: diff --git a/builtin/logical/transit/path_rotate.go b/builtin/logical/transit/path_rotate.go index 93e1951a91..442646ef8c 100644 --- a/builtin/logical/transit/path_rotate.go +++ b/builtin/logical/transit/path_rotate.go @@ -31,48 +31,38 @@ func (b *backend) pathRotateWrite( name := d.Get("name").(string) // Get the policy - lp, err := b.policies.getPolicy(req.Storage, name) + p, lockType, err := b.lm.GetPolicy(req.Storage, name) if err != nil { return nil, err } - - // Error if invalid policy - if lp == nil { + if p == nil { return logical.ErrorResponse("key not found"), logical.ErrInvalidRequest } // Store so we can detect later if this has changed out from under us - keyVersion := lp.Policy().LatestVersion + keyVersion := p.LatestVersion - // lock the policies object so we can refresh - b.policies.Lock() - defer b.policies.Unlock() - lp.Lock() - defer lp.Unlock() + b.lm.UnlockPolicy(name, lockType) // Refresh in case it's changed since before we grabbed the lock - lp, err = b.policies.refreshPolicy(req.Storage, name) + p, err = b.lm.RefreshPolicy(req.Storage, name) if err != nil { return nil, err } - if lp == nil { + if p == nil { return nil, fmt.Errorf("error finding key %s after locking for changes", name) } - - // Verify if wasn't deleted before we grabbed the lock - if lp.Policy() == nil { - return nil, fmt.Errorf("no existing key named %s could be found", name) - } + defer b.lm.UnlockPolicy(name, exclusive) // Make sure that the policy hasn't been rotated simultaneously - if keyVersion != lp.Policy().LatestVersion { + if keyVersion != p.LatestVersion { resp := &logical.Response{} resp.AddWarning("key has been rotated since this endpoint was called; did not perform rotation") return resp, nil } // Rotate the policy - err = lp.Policy().rotate(req.Storage) + err = p.rotate(req.Storage) return nil, err } diff --git a/builtin/logical/transit/policy.go b/builtin/logical/transit/policy.go index ad47af7547..c1b09e2d9a 100644 --- a/builtin/logical/transit/policy.go +++ b/builtin/logical/transit/policy.go @@ -235,6 +235,67 @@ func (p *Policy) Serialize() ([]byte, error) { return json.Marshal(p) } +func (p *Policy) needsUpgrade() bool { + // Ensure we've moved from Key -> Keys + if p.Key != nil && len(p.Key) > 0 { + return true + } + + // With archiving, past assumptions about the length of the keys map are no longer valid + if p.LatestVersion == 0 && len(p.Keys) != 0 { + return true + } + + // We disallow setting the version to 0, since they start at 1 since moving + // to rotate-able keys, so update if it's set to 0 + if p.MinDecryptionVersion == 0 { + return true + } + + // On first load after an upgrade, copy keys to the archive + if p.ArchiveVersion == 0 { + return true + } + + return false +} + +func (p *Policy) upgrade(storage logical.Storage) error { + persistNeeded := false + // Ensure we've moved from Key -> Keys + if p.Key != nil && len(p.Key) > 0 { + p.migrateKeyToKeysMap() + persistNeeded = true + } + + // With archiving, past assumptions about the length of the keys map are no longer valid + if p.LatestVersion == 0 && len(p.Keys) != 0 { + p.LatestVersion = len(p.Keys) + persistNeeded = true + } + + // We disallow setting the version to 0, since they start at 1 since moving + // to rotate-able keys, so update if it's set to 0 + if p.MinDecryptionVersion == 0 { + p.MinDecryptionVersion = 1 + persistNeeded = true + } + + // On first load after an upgrade, copy keys to the archive + if p.ArchiveVersion == 0 { + persistNeeded = true + } + + if persistNeeded { + err := p.Persist(storage) + if err != nil { + return err + } + } + + return nil +} + // DeriveKey is used to derive the encryption key that should // be used depending on the policy. If derivation is disabled the // raw key is used and no context is required, otherwise the KDF diff --git a/builtin/logical/transit/policy_crud.go b/builtin/logical/transit/policy_crud.go deleted file mode 100644 index 46472de962..0000000000 --- a/builtin/logical/transit/policy_crud.go +++ /dev/null @@ -1,188 +0,0 @@ -package transit - -import ( - "encoding/json" - "fmt" - "sync" - - "github.com/hashicorp/vault/logical" -) - -type lockingPolicy interface { - Lock() - RLock() - Unlock() - RUnlock() - Policy() *Policy - SetPolicy(*Policy) -} - -type policyCRUD interface { - // getPolicy returns a lockingPolicy. It performs its own locking according - // to implementation. - getPolicy(storage logical.Storage, name string) (lockingPolicy, error) - - // refreshPolicy returns a lockingPolicy. It does not perform its own - // locking; a write lock must be held before calling. - refreshPolicy(storage logical.Storage, name string) (lockingPolicy, error) - - // generatePolicy generates and returns a lockingPolicy. A write lock must - // be held before calling. - generatePolicy(storage logical.Storage, name string, derived bool) (lockingPolicy, error) - - // deletePolicy deletes a lockingPolicy. A write lock must be held on both - // the CRUD implementation and the lockingPolicy before calling. - deletePolicy(storage logical.Storage, lp lockingPolicy, name string) error - - // These are generally satisfied by embedded mutexes in the implementing struct - Lock() - RLock() - Unlock() - RUnlock() -} - -// The mutex is kept separate from the struct since we may set it to its own -// mutex (if the object is shared) or a shared mutex (if the object isn't -// shared and only the locking is) -type mutexLockingPolicy struct { - mutex *sync.RWMutex - policy *Policy -} - -func (m *mutexLockingPolicy) Lock() { - m.mutex.Lock() -} - -func (m *mutexLockingPolicy) RLock() { - m.mutex.RLock() -} - -func (m *mutexLockingPolicy) Unlock() { - m.mutex.Unlock() -} - -func (m *mutexLockingPolicy) RUnlock() { - m.mutex.RUnlock() -} - -func (m *mutexLockingPolicy) Policy() *Policy { - return m.policy -} - -func (m *mutexLockingPolicy) SetPolicy(p *Policy) { - m.policy = p -} - -// fetchPolicyFromStorage fetches the policy from backend storage. The caller -// should hold the write lock when calling this, to handle upgrades. -func fetchPolicyFromStorage(storage logical.Storage, name string) (*Policy, error) { - // Check if the policy already exists - raw, err := storage.Get("policy/" + name) - if err != nil { - return nil, err - } - if raw == nil { - return nil, nil - } - - // Decode the policy - policy := &Policy{ - Keys: KeyEntryMap{}, - } - err = json.Unmarshal(raw.Value, policy) - if err != nil { - return nil, err - } - - persistNeeded := false - // Ensure we've moved from Key -> Keys - if policy.Key != nil && len(policy.Key) > 0 { - policy.migrateKeyToKeysMap() - persistNeeded = true - } - - // With archiving, past assumptions about the length of the keys map are no longer valid - if policy.LatestVersion == 0 && len(policy.Keys) != 0 { - policy.LatestVersion = len(policy.Keys) - persistNeeded = true - } - - // We disallow setting the version to 0, since they start at 1 since moving - // to rotate-able keys, so update if it's set to 0 - if policy.MinDecryptionVersion == 0 { - policy.MinDecryptionVersion = 1 - persistNeeded = true - } - - // On first load after an upgrade, copy keys to the archive - if policy.ArchiveVersion == 0 { - persistNeeded = true - } - - if persistNeeded { - err = policy.Persist(storage) - if err != nil { - return nil, err - } - } - - return policy, nil -} - -// generatePolicyCommon is used to create a new named policy with a randomly -// generated key. The caller should have a write lock prior to calling this. -func generatePolicyCommon(p policyCRUD, storage logical.Storage, name string, derived bool) (*Policy, error) { - // Make sure this doesn't exist in case it was created before we got the write lock - policy, err := fetchPolicyFromStorage(storage, name) - if err != nil { - return nil, err - } - if policy != nil { - return policy, nil - } - - // Create the policy object - policy = &Policy{ - Name: name, - CipherMode: "aes-gcm", - Derived: derived, - } - if derived { - policy.KDFMode = kdfMode - } - - err = policy.rotate(storage) - if err != nil { - return nil, err - } - - return policy, err -} - -// deletePolicyCommon deletes a policy. The caller should hold the write lock -// for both the policy and lockingPolicy prior to calling this. -func deletePolicyCommon(p policyCRUD, lp lockingPolicy, storage logical.Storage, name string) error { - if lp.Policy() == nil { - // This got deleted before we grabbed the lock - return fmt.Errorf("policy already deleted") - } - - // Verify this hasn't changed - if !lp.Policy().DeletionAllowed { - return fmt.Errorf("deletion not allowed for policy %s", name) - } - - err := storage.Delete("policy/" + name) - if err != nil { - return fmt.Errorf("error deleting policy %s: %s", name, err) - } - - err = storage.Delete("archive/" + name) - if err != nil { - return fmt.Errorf("error deleting archive %s: %s", name, err) - } - - lp.SetPolicy(nil) - - return nil -} diff --git a/builtin/logical/transit/policy_test.go b/builtin/logical/transit/policy_test.go index af49c27d2c..b6d0acf418 100644 --- a/builtin/logical/transit/policy_test.go +++ b/builtin/logical/transit/policy_test.go @@ -16,48 +16,51 @@ func resetKeysArchive() { } func Test_KeyUpgrade(t *testing.T) { - testKeyUpgradeCommon(t, newSimplePolicyCRUD()) - testKeyUpgradeCommon(t, newCachingPolicyCRUD()) + testKeyUpgradeCommon(t, newLockManager(false)) + testKeyUpgradeCommon(t, newLockManager(true)) } -func testKeyUpgradeCommon(t *testing.T, policies policyCRUD) { +func testKeyUpgradeCommon(t *testing.T, lm *lockManager) { storage := &logical.InmemStorage{} - lp, err := policies.generatePolicy(storage, "test", false) + p, lockType, upserted, err := lm.GetPolicyUpsert(storage, "test", false) if err != nil { t.Fatal(err) } - if lp == nil { - t.Fatal("nil lockingPolicy") + if p == nil { + t.Fatal("nil policy") + } + defer lm.UnlockPolicy("test", lockType) + + if !upserted { + t.Fatal("expected an upsert") + } + if lockType != exclusive { + t.Fatal("expected an exclusive lock") } - policy := lp.Policy() - if policy == nil { - t.Fatal("nil policy in lockingPolicy") - } + testBytes := make([]byte, len(p.Keys[1].Key)) + copy(testBytes, p.Keys[1].Key) - testBytes := make([]byte, len(policy.Keys[1].Key)) - copy(testBytes, policy.Keys[1].Key) - - policy.Key = policy.Keys[1].Key - policy.Keys = nil - policy.migrateKeyToKeysMap() - if policy.Key != nil { + p.Key = p.Keys[1].Key + p.Keys = nil + p.migrateKeyToKeysMap() + if p.Key != nil { t.Fatal("policy.Key is not nil") } - if len(policy.Keys) != 1 { + if len(p.Keys) != 1 { t.Fatal("policy.Keys is the wrong size") } - if !reflect.DeepEqual(testBytes, policy.Keys[1].Key) { + if !reflect.DeepEqual(testBytes, p.Keys[1].Key) { t.Fatal("key mismatch") } } func Test_ArchivingUpgrade(t *testing.T) { - testArchivingUpgradeCommon(t, newSimplePolicyCRUD()) - testArchivingUpgradeCommon(t, newCachingPolicyCRUD()) + testArchivingUpgradeCommon(t, newLockManager(false)) + testArchivingUpgradeCommon(t, newLockManager(true)) } -func testArchivingUpgradeCommon(t *testing.T, policies policyCRUD) { +func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) { resetKeysArchive() // First, we generate a policy and rotate it a number of times. Each time @@ -67,30 +70,26 @@ func testArchivingUpgradeCommon(t *testing.T, policies policyCRUD) { storage := &logical.InmemStorage{} - lp, err := policies.generatePolicy(storage, "test", false) + p, lockType, _, err := lm.GetPolicyUpsert(storage, "test", false) if err != nil { t.Fatal(err) } - if lp == nil { - t.Fatal("nil lockingPolicy") - } - - policy := lp.Policy() - if policy == nil { - t.Fatal("nil policy in lockingPolicy") + if p == nil { + t.Fatal("nil policy") } + lm.UnlockPolicy("test", lockType) // Store the initial key in the archive - keysArchive = append(keysArchive, policy.Keys[1]) - checkKeys(t, policy, storage, "initial", 1, 1, 1) + keysArchive = append(keysArchive, p.Keys[1]) + checkKeys(t, p, storage, "initial", 1, 1, 1) for i := 2; i <= 10; i++ { - err = policy.rotate(storage) + err = p.rotate(storage) if err != nil { t.Fatal(err) } - keysArchive = append(keysArchive, policy.Keys[i]) - checkKeys(t, policy, storage, "rotate", i, i, i) + keysArchive = append(keysArchive, p.Keys[i]) + checkKeys(t, p, storage, "rotate", i, i, i) } // Now, wipe the archive and set the archive version to zero @@ -98,53 +97,49 @@ func testArchivingUpgradeCommon(t *testing.T, policies policyCRUD) { if err != nil { t.Fatal(err) } - policy.ArchiveVersion = 0 + p.ArchiveVersion = 0 // Store it, but without calling persist, so we don't trigger // handleArchiving() - buf, err := policy.Serialize() + buf, err := p.Serialize() if err != nil { t.Fatal(err) } // Write the policy into storage err = storage.Put(&logical.StorageEntry{ - Key: "policy/" + policy.Name, + Key: "policy/" + p.Name, Value: buf, }) if err != nil { t.Fatal(err) } - // If it's a caching CRUD, expire from the cache since we modified it + // If we're caching, expire from the cache since we modified it // under-the-hood - if cachingCRUD, ok := policies.(*cachingPolicyCRUD); ok { - delete(cachingCRUD.cache, "test") + if lm.CacheActive() { + delete(lm.cache, "test") } // Now get the policy again; the upgrade should happen automatically - lp, err = policies.getPolicy(storage, "test") + p, lockType, err = lm.GetPolicy(storage, "test") if err != nil { t.Fatal(err) } - if lp == nil { + if p == nil { t.Fatal("nil lockingPolicy") } + lm.UnlockPolicy("test", lockType) - policy = lp.Policy() - if policy == nil { - t.Fatal("nil policy in lockingPolicy") - } - - checkKeys(t, policy, storage, "upgrade", 10, 10, 10) + checkKeys(t, p, storage, "upgrade", 10, 10, 10) } func Test_Archiving(t *testing.T) { - testArchivingCommon(t, newSimplePolicyCRUD()) - testArchivingCommon(t, newCachingPolicyCRUD()) + testArchivingCommon(t, newLockManager(false)) + testArchivingCommon(t, newLockManager(true)) } -func testArchivingCommon(t *testing.T, policies policyCRUD) { +func testArchivingCommon(t *testing.T, lm *lockManager) { resetKeysArchive() // First, we generate a policy and rotate it a number of times. Each time @@ -154,37 +149,33 @@ func testArchivingCommon(t *testing.T, policies policyCRUD) { storage := &logical.InmemStorage{} - lp, err := policies.generatePolicy(storage, "test", false) + p, lockType, _, err := lm.GetPolicyUpsert(storage, "test", false) if err != nil { t.Fatal(err) } - if lp == nil { - t.Fatal("nil lockingPolicy") - } - - policy := lp.Policy() - if policy == nil { - t.Fatal("nil policy in lockingPolicy") + if p == nil { + t.Fatal("nil policy") } + defer lm.UnlockPolicy("test", lockType) // Store the initial key in the archive - keysArchive = append(keysArchive, policy.Keys[1]) - checkKeys(t, policy, storage, "initial", 1, 1, 1) + keysArchive = append(keysArchive, p.Keys[1]) + checkKeys(t, p, storage, "initial", 1, 1, 1) for i := 2; i <= 10; i++ { - err = policy.rotate(storage) + err = p.rotate(storage) if err != nil { t.Fatal(err) } - keysArchive = append(keysArchive, policy.Keys[i]) - checkKeys(t, policy, storage, "rotate", i, i, i) + keysArchive = append(keysArchive, p.Keys[i]) + checkKeys(t, p, storage, "rotate", i, i, i) } // Move the min decryption version up for i := 1; i <= 10; i++ { - policy.MinDecryptionVersion = i + p.MinDecryptionVersion = i - err = policy.Persist(storage) + err = p.Persist(storage) if err != nil { t.Fatal(err) } @@ -196,14 +187,14 @@ func testArchivingCommon(t *testing.T, policies policyCRUD) { // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min // decryption version plus 1 (the min decryption version key // itself) - checkKeys(t, policy, storage, "minadd", 10, 10, policy.LatestVersion-policy.MinDecryptionVersion+1) + checkKeys(t, p, storage, "minadd", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1) } // Move the min decryption version down for i := 10; i >= 1; i-- { - policy.MinDecryptionVersion = i + p.MinDecryptionVersion = i - err = policy.Persist(storage) + err = p.Persist(storage) if err != nil { t.Fatal(err) } @@ -215,7 +206,7 @@ func testArchivingCommon(t *testing.T, policies policyCRUD) { // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min // decryption version plus 1 (the min decryption version key // itself) - checkKeys(t, policy, storage, "minsub", 10, 10, policy.LatestVersion-policy.MinDecryptionVersion+1) + checkKeys(t, p, storage, "minsub", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1) } } diff --git a/builtin/logical/transit/simple_crud.go b/builtin/logical/transit/simple_crud.go deleted file mode 100644 index 6538c7eeae..0000000000 --- a/builtin/logical/transit/simple_crud.go +++ /dev/null @@ -1,82 +0,0 @@ -package transit - -import ( - "sync" - - "github.com/hashicorp/vault/logical" -) - -// Directly implements CRUD operations without caching, mapped to the backend, -// but implements shared locking to ensure that we can't overwrite data on the -// backend from multiple operators -type simplePolicyCRUD struct { - sync.RWMutex - locks map[string]*sync.RWMutex -} - -func newSimplePolicyCRUD() *simplePolicyCRUD { - return &simplePolicyCRUD{ - locks: map[string]*sync.RWMutex{}, - } -} - -// The write lock must be held before calling this; for this CRUD type this -// should always be the case, since the only method not requiring a write lock -// when called is getPolicy, and that itself grabs a write lock before calling -// refreshPolicy -func (p *simplePolicyCRUD) ensureLockExists(name string) { - if p.locks[name] == nil { - p.locks[name] = &sync.RWMutex{} - } -} - -// See general comments on the interface method -func (p *simplePolicyCRUD) getPolicy(storage logical.Storage, name string) (lockingPolicy, error) { - // Use a write lock since fetching the policy can cause a need for upgrade persistence - p.Lock() - defer p.Unlock() - - return p.refreshPolicy(storage, name) -} - -// See general comments on the interface method -func (p *simplePolicyCRUD) refreshPolicy(storage logical.Storage, name string) (lockingPolicy, error) { - p.ensureLockExists(name) - - policy, err := fetchPolicyFromStorage(storage, name) - if err != nil { - return nil, err - } - if policy == nil { - return nil, nil - } - - lp := &mutexLockingPolicy{ - policy: policy, - mutex: p.locks[name], - } - - return lp, nil -} - -// See general comments on the interface method -func (p *simplePolicyCRUD) generatePolicy(storage logical.Storage, name string, derived bool) (lockingPolicy, error) { - p.ensureLockExists(name) - - policy, err := generatePolicyCommon(p, storage, name, derived) - if err != nil { - return nil, err - } - - lp := &mutexLockingPolicy{ - policy: policy, - mutex: p.locks[name], - } - - return lp, nil -} - -// See general comments on the interface method -func (p *simplePolicyCRUD) deletePolicy(storage logical.Storage, lp lockingPolicy, name string) error { - return deletePolicyCommon(p, lp, storage, name) -} From 5ec40a14f4de81cb73c857ea5aa13ae07159c489 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Tue, 26 Apr 2016 19:30:39 +0000 Subject: [PATCH 07/11] Address review feedback --- builtin/logical/transit/lock_manager.go | 49 ++++++++++++++----------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/builtin/logical/transit/lock_manager.go b/builtin/logical/transit/lock_manager.go index 6f994b4edb..16096a3ac6 100644 --- a/builtin/logical/transit/lock_manager.go +++ b/builtin/logical/transit/lock_manager.go @@ -18,7 +18,7 @@ type lockManager struct { locks map[string]*sync.RWMutex // A mutex for the map itself - lockMutex sync.RWMutex + locksMutex sync.RWMutex // If caching is enabled, the map of name to in-memory policy cache cache map[string]*Policy @@ -52,13 +52,13 @@ func (lm *lockManager) UnlockAll(name string) { } func (lm *lockManager) LockPolicy(name string, writeLock bool) { - lm.lockMutex.RLock() + lm.locksMutex.RLock() lock := lm.locks[name] if lock != nil { // We want to give this up before locking the lock, but it's safe -- // the only time we ever write to a value in this map is the first time // we access the value, so it won't be changing out from under us - lm.lockMutex.RUnlock() + lm.locksMutex.RUnlock() if writeLock { lock.Lock() } else { @@ -67,8 +67,8 @@ func (lm *lockManager) LockPolicy(name string, writeLock bool) { return } - lm.lockMutex.RUnlock() - lm.lockMutex.Lock() + lm.locksMutex.RUnlock() + lm.locksMutex.Lock() // Don't defer the unlock call because if we get a valid lock below we want // to release the lock mutex right away to avoid the possibility of @@ -77,7 +77,7 @@ func (lm *lockManager) LockPolicy(name string, writeLock bool) { // Check to make sure it hasn't been created since lock = lm.locks[name] if lock != nil { - lm.lockMutex.Unlock() + lm.locksMutex.Unlock() if writeLock { lock.Lock() } else { @@ -88,7 +88,7 @@ func (lm *lockManager) LockPolicy(name string, writeLock bool) { lock = &sync.RWMutex{} lm.locks[name] = lock - lm.lockMutex.Unlock() + lm.locksMutex.Unlock() if writeLock { lock.Lock() } else { @@ -97,9 +97,9 @@ func (lm *lockManager) LockPolicy(name string, writeLock bool) { } func (lm *lockManager) UnlockPolicy(name string, writeLock bool) { - lm.lockMutex.RLock() + lm.locksMutex.RLock() lock := lm.locks[name] - lm.lockMutex.RUnlock() + lm.locksMutex.RUnlock() if writeLock { lock.Unlock() @@ -131,7 +131,7 @@ func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, ups lm.globalMutex.RLock() p = lm.cache[name] if p != nil { - defer lm.globalMutex.RUnlock() + lm.globalMutex.RUnlock() return } lm.globalMutex.RUnlock() @@ -139,16 +139,20 @@ func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, ups // When we return, since we didn't have the policy in the cache, if // there was no error, write the value in. defer func() { - if err == nil { - lm.globalMutex.Lock() - defer lm.globalMutex.Unlock() - // Make sure a policy didn't appear - exp := lm.cache[name] - if exp != nil { - p = exp - return - } + lm.globalMutex.Lock() + defer lm.globalMutex.Unlock() + // Make sure a policy didn't appear. If so, it will only be set if + // there was no error, so now just clear the error and return that + // policy. + exp := lm.cache[name] + if exp != nil { + upserted = false + err = nil + p = exp + return + } + if err == nil { lm.cache[name] = p } }() @@ -156,13 +160,13 @@ func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, ups p, err = lm.getStoredPolicy(storage, name) if err != nil { - defer lm.UnlockPolicy(name, shared) + lm.UnlockPolicy(name, shared) return } if p == nil { if !upsert { - defer lm.UnlockPolicy(name, shared) + lm.UnlockPolicy(name, shared) return } @@ -279,7 +283,8 @@ func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error return nil } -// When this function returns it's the responsibility of the caller to call UnlockAll if err is not nil +// When this function returns it's the responsibility of the caller to call +// UnlockPolicy if err is nil and policy is not nil func (lm *lockManager) RefreshPolicy(storage logical.Storage, name string) (p *Policy, err error) { lm.LockPolicy(name, exclusive) From 16267d511564fef582c83f433e3ce692d86028e6 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Tue, 26 Apr 2016 20:57:07 +0000 Subject: [PATCH 08/11] Change use-hint of lockAll and lockPolicy --- builtin/logical/transit/lock_manager.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/builtin/logical/transit/lock_manager.go b/builtin/logical/transit/lock_manager.go index 16096a3ac6..ad53482af5 100644 --- a/builtin/logical/transit/lock_manager.go +++ b/builtin/logical/transit/lock_manager.go @@ -41,9 +41,9 @@ func (lm *lockManager) CacheActive() bool { return lm.cache != nil } -func (lm *lockManager) LockAll(name string) { +func (lm *lockManager) lockAll(name string) { lm.globalMutex.Lock() - lm.LockPolicy(name, exclusive) + lm.lockPolicy(name, exclusive) } func (lm *lockManager) UnlockAll(name string) { @@ -51,7 +51,7 @@ func (lm *lockManager) UnlockAll(name string) { lm.globalMutex.Unlock() } -func (lm *lockManager) LockPolicy(name string, writeLock bool) { +func (lm *lockManager) lockPolicy(name string, writeLock bool) { lm.locksMutex.RLock() lock := lm.locks[name] if lock != nil { @@ -125,7 +125,7 @@ func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, ups // complicated lock juggling as we call various functions. We'll also defer // the store into the cache. lockType = shared - lm.LockPolicy(name, shared) + lm.lockPolicy(name, shared) if lm.CacheActive() { lm.globalMutex.RLock() @@ -177,7 +177,7 @@ func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, ups // simply checking the disk again is sufficient. lm.UnlockPolicy(name, shared) lockType = exclusive - lm.LockPolicy(name, exclusive) + lm.lockPolicy(name, exclusive) p, err = lm.getStoredPolicy(storage, name) if err != nil { @@ -212,7 +212,7 @@ func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, ups if p.needsUpgrade() { lm.UnlockPolicy(name, shared) lockType = exclusive - lm.LockPolicy(name, exclusive) + lm.lockPolicy(name, exclusive) // Reload the policy with the write lock to ensure we still need the upgrade p, err = lm.getStoredPolicy(storage, name) @@ -241,7 +241,7 @@ func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, ups } func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error { - lm.LockAll(name) + lm.lockAll(name) defer lm.UnlockAll(name) var p *Policy @@ -286,7 +286,7 @@ func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error // When this function returns it's the responsibility of the caller to call // UnlockPolicy if err is nil and policy is not nil func (lm *lockManager) RefreshPolicy(storage logical.Storage, name string) (p *Policy, err error) { - lm.LockPolicy(name, exclusive) + lm.lockPolicy(name, exclusive) if lm.CacheActive() { p = lm.cache[name] From bf7ad912e104769d57bac6b75ed580b7f5a9ce1e Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 28 Apr 2016 21:29:51 +0000 Subject: [PATCH 09/11] Remove some deferring --- builtin/logical/transit/lock_manager.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/builtin/logical/transit/lock_manager.go b/builtin/logical/transit/lock_manager.go index ad53482af5..bd817ed3e0 100644 --- a/builtin/logical/transit/lock_manager.go +++ b/builtin/logical/transit/lock_manager.go @@ -217,12 +217,12 @@ func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, ups // Reload the policy with the write lock to ensure we still need the upgrade p, err = lm.getStoredPolicy(storage, name) if err != nil { - defer lm.UnlockPolicy(name, exclusive) + lm.UnlockPolicy(name, exclusive) return } if p == nil { - defer lm.UnlockPolicy(name, exclusive) err = fmt.Errorf("error reloading policy for upgrade") + lm.UnlockPolicy(name, exclusive) return } @@ -300,7 +300,7 @@ func (lm *lockManager) RefreshPolicy(storage logical.Storage, name string) (p *P p, err = lm.getStoredPolicy(storage, name) if err != nil { - defer lm.UnlockPolicy(name, exclusive) + lm.UnlockPolicy(name, exclusive) return } From 027d570f7f07ec51d1ae92604d5b24971dcf15ae Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Mon, 2 May 2016 23:46:39 -0400 Subject: [PATCH 10/11] Massively simplify lock handling based on feedback --- builtin/logical/transit/lock_manager.go | 278 +++++++++++------------- builtin/logical/transit/path_config.go | 28 +-- builtin/logical/transit/path_datakey.go | 6 +- builtin/logical/transit/path_decrypt.go | 7 +- builtin/logical/transit/path_encrypt.go | 19 +- builtin/logical/transit/path_keys.go | 27 +-- builtin/logical/transit/path_rewrap.go | 7 +- builtin/logical/transit/path_rotate.go | 29 +-- 8 files changed, 170 insertions(+), 231 deletions(-) diff --git a/builtin/logical/transit/lock_manager.go b/builtin/logical/transit/lock_manager.go index bd817ed3e0..9f20a423a5 100644 --- a/builtin/logical/transit/lock_manager.go +++ b/builtin/logical/transit/lock_manager.go @@ -2,6 +2,7 @@ package transit import ( "encoding/json" + "errors" "fmt" "sync" @@ -13,6 +14,10 @@ const ( exclusive = true ) +var ( + errNeedExclusiveLock = errors.New("an exclusive lock is needed for this operation") +) + type lockManager struct { // A lock for each named key locks map[string]*sync.RWMutex @@ -24,7 +29,7 @@ type lockManager struct { cache map[string]*Policy // Used for global locking, and as the cache map mutex - globalMutex sync.RWMutex + cacheMutex sync.RWMutex } func newLockManager(cacheDisabled bool) *lockManager { @@ -41,17 +46,7 @@ func (lm *lockManager) CacheActive() bool { return lm.cache != nil } -func (lm *lockManager) lockAll(name string) { - lm.globalMutex.Lock() - lm.lockPolicy(name, exclusive) -} - -func (lm *lockManager) UnlockAll(name string) { - lm.UnlockPolicy(name, exclusive) - lm.globalMutex.Unlock() -} - -func (lm *lockManager) lockPolicy(name string, writeLock bool) { +func (lm *lockManager) policyLock(name string, lockType bool) *sync.RWMutex { lm.locksMutex.RLock() lock := lm.locks[name] if lock != nil { @@ -59,12 +54,12 @@ func (lm *lockManager) lockPolicy(name string, writeLock bool) { // the only time we ever write to a value in this map is the first time // we access the value, so it won't be changing out from under us lm.locksMutex.RUnlock() - if writeLock { + if lockType == exclusive { lock.Lock() } else { lock.RLock() } - return + return lock } lm.locksMutex.RUnlock() @@ -78,117 +73,121 @@ func (lm *lockManager) lockPolicy(name string, writeLock bool) { lock = lm.locks[name] if lock != nil { lm.locksMutex.Unlock() - if writeLock { + if lockType == exclusive { lock.Lock() } else { lock.RLock() } - return + return lock } lock = &sync.RWMutex{} lm.locks[name] = lock lm.locksMutex.Unlock() - if writeLock { + if lockType == exclusive { lock.Lock() } else { lock.RLock() } + + return lock } -func (lm *lockManager) UnlockPolicy(name string, writeLock bool) { - lm.locksMutex.RLock() - lock := lm.locks[name] - lm.locksMutex.RUnlock() - - if writeLock { +func (lm *lockManager) UnlockPolicy(lock *sync.RWMutex, lockType bool) { + if lockType == exclusive { lock.Unlock() } else { lock.RUnlock() } } -func (lm *lockManager) GetPolicy(storage logical.Storage, name string) (*Policy, bool, error) { - p, lt, _, err := lm.getPolicyCommon(storage, name, false, false) - return p, lt, err +// Get the policy with a read lock. If we get an error saying an exclusive lock +// is needed (for instance, for an upgrade/migration), give up the read lock, +// call again with an exclusive lock, then swap back out for a read lock. +func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) { + p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, shared) + if err == nil || + (err != nil && err != errNeedExclusiveLock) { + return p, lock, err + } + + // Try again while asking for an exlusive lock + p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, exclusive) + if err != nil || p == nil || lock == nil { + return p, lock, err + } + + lock.Unlock() + + p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, shared) + return p, lock, err } -func (lm *lockManager) GetPolicyUpsert(storage logical.Storage, name string, derived bool) (*Policy, bool, bool, error) { - return lm.getPolicyCommon(storage, name, true, derived) +// Get the policy with an exclusive lock +func (lm *lockManager) GetPolicyExclusive(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) { + p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, exclusive) + return p, lock, err +} + +// Get the policy with a read lock; if it returns that an exclusive lock is +// needed, retry. If successful, call one more time to get a read lock and +// return the value. +func (lm *lockManager) GetPolicyUpsert(storage logical.Storage, name string, derived bool) (*Policy, *sync.RWMutex, bool, error) { + p, lock, upserted, err := lm.getPolicyCommon(storage, name, true, derived, shared) + if err == nil || + (err != nil && err != errNeedExclusiveLock) { + return p, lock, upserted, err + } + + // Try again while asking for an exlusive lock + p, lock, upserted, err = lm.getPolicyCommon(storage, name, true, derived, exclusive) + if err != nil || p == nil || lock == nil { + return p, lock, upserted, err + } + + lock.Unlock() + + return lm.getPolicyCommon(storage, name, true, derived, shared) } // When the function returns, a lock will be held on the policy if err == nil. -// The type of lock will be indicated by the return value. It is the caller's -// responsibility to unlock. -func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, upsert, derived bool) (p *Policy, lockType bool, upserted bool, err error) { - // If we are using a cache, lock it now to avoid having to do really - // complicated lock juggling as we call various functions. We'll also defer - // the store into the cache. - lockType = shared - lm.lockPolicy(name, shared) +// It is the caller's responsibility to unlock. +func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, upsert, derived, lockType bool) (*Policy, *sync.RWMutex, bool, error) { + lock := lm.policyLock(name, lockType) + var p *Policy + var err error + + // Check if it's in our cache. If so, return right away. if lm.CacheActive() { - lm.globalMutex.RLock() + lm.cacheMutex.RLock() p = lm.cache[name] if p != nil { - lm.globalMutex.RUnlock() - return + lm.cacheMutex.RUnlock() + return p, lock, false, nil } - lm.globalMutex.RUnlock() - - // When we return, since we didn't have the policy in the cache, if - // there was no error, write the value in. - defer func() { - lm.globalMutex.Lock() - defer lm.globalMutex.Unlock() - // Make sure a policy didn't appear. If so, it will only be set if - // there was no error, so now just clear the error and return that - // policy. - exp := lm.cache[name] - if exp != nil { - upserted = false - err = nil - p = exp - return - } - - if err == nil { - lm.cache[name] = p - } - }() + lm.cacheMutex.RUnlock() } + // Load it from storage p, err = lm.getStoredPolicy(storage, name) if err != nil { - lm.UnlockPolicy(name, shared) - return + lm.UnlockPolicy(lock, lockType) + return nil, nil, false, err } if p == nil { + // This is the only place we upsert a new policy, so if upsert is not + // specified, or the lock type is wrong, unllock before returning if !upsert { - lm.UnlockPolicy(name, shared) - return + lm.UnlockPolicy(lock, lockType) + return nil, nil, false, nil } - // Get an exlusive lock; on success, check again to ensure that no - // policy exists. Note that if we are using a cache we will already be - // serializing this entire code path and it's currently the only one - // that generates policies, so we don't need to check the cache here; - // simply checking the disk again is sufficient. - lm.UnlockPolicy(name, shared) - lockType = exclusive - lm.lockPolicy(name, exclusive) - - p, err = lm.getStoredPolicy(storage, name) - if err != nil { - defer lm.UnlockPolicy(name, exclusive) - return + if lockType != exclusive { + lm.UnlockPolicy(lock, lockType) + return nil, nil, false, errNeedExclusiveLock } - if p != nil { - return - } - - upserted = true p = &Policy{ Name: name, @@ -201,58 +200,75 @@ func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, ups err = p.rotate(storage) if err != nil { - defer lm.UnlockPolicy(name, exclusive) - p = nil + lm.UnlockPolicy(lock, lockType) + return nil, nil, false, err + } + + if lm.CacheActive() { + // Since we didn't have the policy in the cache, if there was no + // error, write the value in. + lm.cacheMutex.Lock() + defer lm.cacheMutex.Unlock() + // Make sure a policy didn't appear. If so, it will only be set if + // there was no error, so assume it's good and return that + exp := lm.cache[name] + if exp != nil { + return exp, lock, false, nil + } + if err == nil { + lm.cache[name] = p + } } // We don't need to worry about upgrading since it will be a new policy - return + return p, lock, true, nil } if p.needsUpgrade() { - lm.UnlockPolicy(name, shared) - lockType = exclusive - lm.lockPolicy(name, exclusive) - - // Reload the policy with the write lock to ensure we still need the upgrade - p, err = lm.getStoredPolicy(storage, name) - if err != nil { - lm.UnlockPolicy(name, exclusive) - return - } - if p == nil { - err = fmt.Errorf("error reloading policy for upgrade") - lm.UnlockPolicy(name, exclusive) - return - } - - if !p.needsUpgrade() { - // Already happened, return the newly loaded policy - return + if lockType == shared { + lm.UnlockPolicy(lock, lockType) + return nil, nil, false, errNeedExclusiveLock } err = p.upgrade(storage) if err != nil { - defer lm.UnlockPolicy(name, exclusive) + lm.UnlockPolicy(lock, lockType) + return nil, nil, false, err } } - return + if lm.CacheActive() { + // Since we didn't have the policy in the cache, if there was no + // error, write the value in. + lm.cacheMutex.Lock() + defer lm.cacheMutex.Unlock() + // Make sure a policy didn't appear. If so, it will only be set if + // there was no error, so assume it's good and return that + exp := lm.cache[name] + if exp != nil { + return exp, lock, false, nil + } + if err == nil { + lm.cache[name] = p + } + } + + return p, lock, false, nil } func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error { - lm.lockAll(name) - defer lm.UnlockAll(name) + lm.cacheMutex.Lock() + lock := lm.policyLock(name, exclusive) + defer lock.Unlock() + defer lm.cacheMutex.Unlock() var p *Policy var err error if lm.CacheActive() { p = lm.cache[name] - if p == nil { - return fmt.Errorf("could not delete policy; not found") - } - } else { + } + if p == nil { p, err = lm.getStoredPolicy(storage, name) if err != nil { return err @@ -283,42 +299,6 @@ func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error return nil } -// When this function returns it's the responsibility of the caller to call -// UnlockPolicy if err is nil and policy is not nil -func (lm *lockManager) RefreshPolicy(storage logical.Storage, name string) (p *Policy, err error) { - lm.lockPolicy(name, exclusive) - - if lm.CacheActive() { - p = lm.cache[name] - if p != nil { - return - } - err = fmt.Errorf("could not refresh policy; not found") - defer lm.UnlockPolicy(name, exclusive) - return - } - - p, err = lm.getStoredPolicy(storage, name) - if err != nil { - lm.UnlockPolicy(name, exclusive) - return - } - - if p == nil { - err = fmt.Errorf("could not refresh policy; not found") - defer lm.UnlockPolicy(name, exclusive) - } - - if p.needsUpgrade() { - err = p.upgrade(storage) - if err != nil { - defer lm.UnlockPolicy(name, exclusive) - } - } - - return -} - func (lm *lockManager) getStoredPolicy(storage logical.Storage, name string) (*Policy, error) { // Check if the policy already exists raw, err := storage.Get("policy/" + name) diff --git a/builtin/logical/transit/path_config.go b/builtin/logical/transit/path_config.go index eabb3603f6..5dc84c3f76 100644 --- a/builtin/logical/transit/path_config.go +++ b/builtin/logical/transit/path_config.go @@ -42,7 +42,10 @@ func (b *backend) pathConfigWrite( name := d.Get("name").(string) // Check if the policy already exists before we lock everything - p, lockType, err := b.lm.GetPolicy(req.Storage, name) + p, lock, err := b.lm.GetPolicyExclusive(req.Storage, name) + if lock != nil { + defer lock.Unlock() + } if err != nil { return nil, err } @@ -52,31 +55,8 @@ func (b *backend) pathConfigWrite( logical.ErrInvalidRequest } - // Store some values so we can detect a change when we lock everything - currDeletionAllowed := p.DeletionAllowed - currMinDecryptionVersion := p.MinDecryptionVersion - - b.lm.UnlockPolicy(name, lockType) - - // Refresh in case it's changed since before we grabbed the lock - p, err = b.lm.RefreshPolicy(req.Storage, name) - if err != nil { - return nil, err - } - if p == nil { - return nil, fmt.Errorf("error finding key %s after locking for changes", name) - } - defer b.lm.UnlockPolicy(name, exclusive) - resp := &logical.Response{} - // Check for anything to have been updated since we got the write lock - if currDeletionAllowed != p.DeletionAllowed || - currMinDecryptionVersion != p.MinDecryptionVersion { - resp.AddWarning("key configuration has changed since this endpoint was called, not updating") - return resp, nil - } - persistNeeded := false minDecryptionVersionRaw, ok := d.GetOk("min_decryption_version") diff --git a/builtin/logical/transit/path_datakey.go b/builtin/logical/transit/path_datakey.go index bd6fca6517..817529e9dd 100644 --- a/builtin/logical/transit/path_datakey.go +++ b/builtin/logical/transit/path_datakey.go @@ -73,14 +73,16 @@ func (b *backend) pathDatakeyWrite( } // Get the policy - p, lockType, err := b.lm.GetPolicy(req.Storage, name) + p, lock, err := b.lm.GetPolicyShared(req.Storage, name) + if lock != nil { + defer lock.RUnlock() + } if err != nil { return nil, err } if p == nil { return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } - defer b.lm.UnlockPolicy(name, lockType) newKey := make([]byte, 32) bits := d.Get("bits").(int) diff --git a/builtin/logical/transit/path_decrypt.go b/builtin/logical/transit/path_decrypt.go index 8c5799331f..254b2b5093 100644 --- a/builtin/logical/transit/path_decrypt.go +++ b/builtin/logical/transit/path_decrypt.go @@ -58,7 +58,10 @@ func (b *backend) pathDecryptWrite( } // Get the policy - p, lockType, err := b.lm.GetPolicy(req.Storage, name) + p, lock, err := b.lm.GetPolicyShared(req.Storage, name) + if lock != nil { + defer lock.RUnlock() + } if err != nil { return nil, err } @@ -66,8 +69,6 @@ func (b *backend) pathDecryptWrite( return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } - defer b.lm.UnlockPolicy(name, lockType) - plaintext, err := p.Decrypt(context, ciphertext) if err != nil { switch err.(type) { diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go index bba7667972..3494faac81 100644 --- a/builtin/logical/transit/path_encrypt.go +++ b/builtin/logical/transit/path_encrypt.go @@ -3,6 +3,7 @@ package transit import ( "encoding/base64" "fmt" + "sync" "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/logical" @@ -44,13 +45,13 @@ func (b *backend) pathEncrypt() *framework.Path { func (b *backend) pathEncryptExistenceCheck( req *logical.Request, d *framework.FieldData) (bool, error) { name := d.Get("name").(string) - p, lockType, err := b.lm.GetPolicy(req.Storage, name) + p, lock, err := b.lm.GetPolicyShared(req.Storage, name) + if lock != nil { + defer lock.RUnlock() + } if err != nil { return false, err } - if p != nil { - defer b.lm.UnlockPolicy(name, lockType) - } return p != nil, nil } @@ -75,12 +76,15 @@ func (b *backend) pathEncryptWrite( // Get the policy var p *Policy - var lockType bool + var lock *sync.RWMutex var upserted bool if req.Operation == logical.CreateOperation { - p, lockType, upserted, err = b.lm.GetPolicyUpsert(req.Storage, name, len(context) != 0) + p, lock, upserted, err = b.lm.GetPolicyUpsert(req.Storage, name, len(context) != 0) } else { - p, lockType, err = b.lm.GetPolicy(req.Storage, name) + p, lock, err = b.lm.GetPolicyShared(req.Storage, name) + } + if lock != nil { + defer lock.RUnlock() } if err != nil { return nil, err @@ -88,7 +92,6 @@ func (b *backend) pathEncryptWrite( if p == nil { return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } - defer b.lm.UnlockPolicy(name, lockType) ciphertext, err := p.Encrypt(context, value) if err != nil { diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index 956b295d15..466ca9b508 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -39,7 +39,10 @@ func (b *backend) pathPolicyWrite( name := d.Get("name").(string) derived := d.Get("derived").(bool) - p, lockType, upserted, err := b.lm.GetPolicyUpsert(req.Storage, name, derived) + p, lock, upserted, err := b.lm.GetPolicyUpsert(req.Storage, name, derived) + if lock != nil { + defer lock.RUnlock() + } if err != nil { return nil, err } @@ -47,8 +50,6 @@ func (b *backend) pathPolicyWrite( return nil, fmt.Errorf("error generating key: returned policy was nil") } - defer b.lm.UnlockPolicy(name, lockType) - resp := &logical.Response{} if !upserted { resp.AddWarning(fmt.Sprintf("key %s already existed", name)) @@ -61,7 +62,10 @@ func (b *backend) pathPolicyRead( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - p, lockType, err := b.lm.GetPolicy(req.Storage, name) + p, lock, err := b.lm.GetPolicyShared(req.Storage, name) + if lock != nil { + defer lock.RUnlock() + } if err != nil { return nil, err } @@ -69,8 +73,6 @@ func (b *backend) pathPolicyRead( return nil, nil } - defer b.lm.UnlockPolicy(name, lockType) - // Return the response resp := &logical.Response{ Data: map[string]interface{}{ @@ -99,17 +101,8 @@ func (b *backend) pathPolicyDelete( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - // Some sanity checking before we lock it all in the DeletePolicy method - p, lockType, err := b.lm.GetPolicy(req.Storage, name) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf("error looking up policy %s, error is %s", name, err)), err - } - if p == nil { - return logical.ErrorResponse(fmt.Sprintf("no such key %s", name)), logical.ErrInvalidRequest - } - b.lm.UnlockPolicy(name, lockType) - - err = b.lm.DeletePolicy(req.Storage, name) + // Delete does its own locking + err := b.lm.DeletePolicy(req.Storage, name) if err != nil { return logical.ErrorResponse(fmt.Sprintf("error deleting policy %s: %s", name, err)), err } diff --git a/builtin/logical/transit/path_rewrap.go b/builtin/logical/transit/path_rewrap.go index 61e3f607e8..a5854feeea 100644 --- a/builtin/logical/transit/path_rewrap.go +++ b/builtin/logical/transit/path_rewrap.go @@ -59,7 +59,10 @@ func (b *backend) pathRewrapWrite( } // Get the policy - p, lockType, err := b.lm.GetPolicy(req.Storage, name) + p, lock, err := b.lm.GetPolicyShared(req.Storage, name) + if lock != nil { + defer lock.RUnlock() + } if err != nil { return nil, err } @@ -68,8 +71,6 @@ func (b *backend) pathRewrapWrite( return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } - defer b.lm.UnlockPolicy(name, lockType) - plaintext, err := p.Decrypt(context, value) if err != nil { switch err.(type) { diff --git a/builtin/logical/transit/path_rotate.go b/builtin/logical/transit/path_rotate.go index 442646ef8c..f10b78d56a 100644 --- a/builtin/logical/transit/path_rotate.go +++ b/builtin/logical/transit/path_rotate.go @@ -1,8 +1,6 @@ package transit import ( - "fmt" - "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -31,7 +29,10 @@ func (b *backend) pathRotateWrite( name := d.Get("name").(string) // Get the policy - p, lockType, err := b.lm.GetPolicy(req.Storage, name) + p, lock, err := b.lm.GetPolicyExclusive(req.Storage, name) + if lock != nil { + defer lock.Unlock() + } if err != nil { return nil, err } @@ -39,28 +40,6 @@ func (b *backend) pathRotateWrite( return logical.ErrorResponse("key not found"), logical.ErrInvalidRequest } - // Store so we can detect later if this has changed out from under us - keyVersion := p.LatestVersion - - b.lm.UnlockPolicy(name, lockType) - - // Refresh in case it's changed since before we grabbed the lock - p, err = b.lm.RefreshPolicy(req.Storage, name) - if err != nil { - return nil, err - } - if p == nil { - return nil, fmt.Errorf("error finding key %s after locking for changes", name) - } - defer b.lm.UnlockPolicy(name, exclusive) - - // Make sure that the policy hasn't been rotated simultaneously - if keyVersion != p.LatestVersion { - resp := &logical.Response{} - resp.AddWarning("key has been rotated since this endpoint was called; did not perform rotation") - return resp, nil - } - // Rotate the policy err = p.rotate(req.Storage) From e48cb2e840a1ba1e840834dccceb2f3e827a56d2 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Tue, 3 May 2016 00:19:18 -0400 Subject: [PATCH 11/11] Add some more tests around deletion and fix upsert status returning --- builtin/logical/transit/lock_manager.go | 11 ++-- builtin/logical/transit/policy_test.go | 83 ++++++++++++++++++++----- 2 files changed, 74 insertions(+), 20 deletions(-) diff --git a/builtin/logical/transit/lock_manager.go b/builtin/logical/transit/lock_manager.go index 9f20a423a5..515515726f 100644 --- a/builtin/logical/transit/lock_manager.go +++ b/builtin/logical/transit/lock_manager.go @@ -133,21 +133,24 @@ func (lm *lockManager) GetPolicyExclusive(storage logical.Storage, name string) // needed, retry. If successful, call one more time to get a read lock and // return the value. func (lm *lockManager) GetPolicyUpsert(storage logical.Storage, name string, derived bool) (*Policy, *sync.RWMutex, bool, error) { - p, lock, upserted, err := lm.getPolicyCommon(storage, name, true, derived, shared) + p, lock, _, err := lm.getPolicyCommon(storage, name, true, derived, shared) if err == nil || (err != nil && err != errNeedExclusiveLock) { - return p, lock, upserted, err + return p, lock, false, err } // Try again while asking for an exlusive lock - p, lock, upserted, err = lm.getPolicyCommon(storage, name, true, derived, exclusive) + p, lock, upserted, err := lm.getPolicyCommon(storage, name, true, derived, exclusive) if err != nil || p == nil || lock == nil { return p, lock, upserted, err } lock.Unlock() - return lm.getPolicyCommon(storage, name, true, derived, shared) + // Now get a shared lock for the return, but preserve the value of upsert + p, lock, _, err = lm.getPolicyCommon(storage, name, true, derived, shared) + + return p, lock, upserted, err } // When the function returns, a lock will be held on the policy if err == nil. diff --git a/builtin/logical/transit/policy_test.go b/builtin/logical/transit/policy_test.go index b6d0acf418..f326416208 100644 --- a/builtin/logical/transit/policy_test.go +++ b/builtin/logical/transit/policy_test.go @@ -22,21 +22,19 @@ func Test_KeyUpgrade(t *testing.T) { func testKeyUpgradeCommon(t *testing.T, lm *lockManager) { storage := &logical.InmemStorage{} - p, lockType, upserted, err := lm.GetPolicyUpsert(storage, "test", false) + p, lock, upserted, err := lm.GetPolicyUpsert(storage, "test", false) + if lock != nil { + defer lock.RUnlock() + } if err != nil { t.Fatal(err) } if p == nil { t.Fatal("nil policy") } - defer lm.UnlockPolicy("test", lockType) - if !upserted { t.Fatal("expected an upsert") } - if lockType != exclusive { - t.Fatal("expected an exclusive lock") - } testBytes := make([]byte, len(p.Keys[1].Key)) copy(testBytes, p.Keys[1].Key) @@ -70,14 +68,14 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) { storage := &logical.InmemStorage{} - p, lockType, _, err := lm.GetPolicyUpsert(storage, "test", false) + p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false) if err != nil { t.Fatal(err) } - if p == nil { - t.Fatal("nil policy") + if p == nil || lock == nil { + t.Fatal("nil policy or lock") } - lm.UnlockPolicy("test", lockType) + lock.RUnlock() // Store the initial key in the archive keysArchive = append(keysArchive, p.Keys[1]) @@ -122,16 +120,67 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) { } // Now get the policy again; the upgrade should happen automatically - p, lockType, err = lm.GetPolicy(storage, "test") + p, lock, err = lm.GetPolicyShared(storage, "test") if err != nil { t.Fatal(err) } - if p == nil { - t.Fatal("nil lockingPolicy") + if p == nil || lock == nil { + t.Fatal("nil policy or lock") } - lm.UnlockPolicy("test", lockType) + lock.RUnlock() checkKeys(t, p, storage, "upgrade", 10, 10, 10) + + // Let's check some deletion logic while we're at it + + // The policy should be in there + if lm.CacheActive() && lm.cache["test"] == nil { + t.Fatal("nil policy in cache") + } + + // First we'll do this wrong, by not setting the deletion flag + err = lm.DeletePolicy(storage, "test") + if err == nil { + t.Fatal("got nil error, but should not have been able to delete since we didn't set the deletion flag on the policy") + } + + // The policy should still be in there + if lm.CacheActive() && lm.cache["test"] == nil { + t.Fatal("nil policy in cache") + } + + p, lock, err = lm.GetPolicyShared(storage, "test") + if err != nil { + t.Fatal(err) + } + if p == nil || lock == nil { + t.Fatal("policy or lock nil after bad delete") + } + lock.RUnlock() + + // Now do it properly + p.DeletionAllowed = true + err = p.Persist(storage) + if err != nil { + t.Fatal(err) + } + err = lm.DeletePolicy(storage, "test") + if err != nil { + t.Fatal(err) + } + + // The policy should *not* be in there + if lm.CacheActive() && lm.cache["test"] != nil { + t.Fatal("non-nil policy in cache") + } + + p, lock, err = lm.GetPolicyShared(storage, "test") + if err != nil { + t.Fatal(err) + } + if p != nil || lock != nil { + t.Fatal("policy or lock not nil after delete") + } } func Test_Archiving(t *testing.T) { @@ -149,14 +198,16 @@ func testArchivingCommon(t *testing.T, lm *lockManager) { storage := &logical.InmemStorage{} - p, lockType, _, err := lm.GetPolicyUpsert(storage, "test", false) + p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false) + if lock != nil { + defer lock.RUnlock() + } if err != nil { t.Fatal(err) } if p == nil { t.Fatal("nil policy") } - defer lm.UnlockPolicy("test", lockType) // Store the initial key in the archive keysArchive = append(keysArchive, p.Keys[1])