From 027d570f7f07ec51d1ae92604d5b24971dcf15ae Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Mon, 2 May 2016 23:46:39 -0400 Subject: [PATCH] Massively simplify lock handling based on feedback --- builtin/logical/transit/lock_manager.go | 278 +++++++++++------------- builtin/logical/transit/path_config.go | 28 +-- builtin/logical/transit/path_datakey.go | 6 +- builtin/logical/transit/path_decrypt.go | 7 +- builtin/logical/transit/path_encrypt.go | 19 +- builtin/logical/transit/path_keys.go | 27 +-- builtin/logical/transit/path_rewrap.go | 7 +- builtin/logical/transit/path_rotate.go | 29 +-- 8 files changed, 170 insertions(+), 231 deletions(-) diff --git a/builtin/logical/transit/lock_manager.go b/builtin/logical/transit/lock_manager.go index bd817ed3e0..9f20a423a5 100644 --- a/builtin/logical/transit/lock_manager.go +++ b/builtin/logical/transit/lock_manager.go @@ -2,6 +2,7 @@ package transit import ( "encoding/json" + "errors" "fmt" "sync" @@ -13,6 +14,10 @@ const ( exclusive = true ) +var ( + errNeedExclusiveLock = errors.New("an exclusive lock is needed for this operation") +) + type lockManager struct { // A lock for each named key locks map[string]*sync.RWMutex @@ -24,7 +29,7 @@ type lockManager struct { cache map[string]*Policy // Used for global locking, and as the cache map mutex - globalMutex sync.RWMutex + cacheMutex sync.RWMutex } func newLockManager(cacheDisabled bool) *lockManager { @@ -41,17 +46,7 @@ func (lm *lockManager) CacheActive() bool { return lm.cache != nil } -func (lm *lockManager) lockAll(name string) { - lm.globalMutex.Lock() - lm.lockPolicy(name, exclusive) -} - -func (lm *lockManager) UnlockAll(name string) { - lm.UnlockPolicy(name, exclusive) - lm.globalMutex.Unlock() -} - -func (lm *lockManager) lockPolicy(name string, writeLock bool) { +func (lm *lockManager) policyLock(name string, lockType bool) *sync.RWMutex { lm.locksMutex.RLock() lock := lm.locks[name] if lock != nil { @@ -59,12 +54,12 @@ func (lm *lockManager) lockPolicy(name 
string, writeLock bool) { // the only time we ever write to a value in this map is the first time // we access the value, so it won't be changing out from under us lm.locksMutex.RUnlock() - if writeLock { + if lockType == exclusive { lock.Lock() } else { lock.RLock() } - return + return lock } lm.locksMutex.RUnlock() @@ -78,117 +73,121 @@ func (lm *lockManager) lockPolicy(name string, writeLock bool) { lock = lm.locks[name] if lock != nil { lm.locksMutex.Unlock() - if writeLock { + if lockType == exclusive { lock.Lock() } else { lock.RLock() } - return + return lock } lock = &sync.RWMutex{} lm.locks[name] = lock lm.locksMutex.Unlock() - if writeLock { + if lockType == exclusive { lock.Lock() } else { lock.RLock() } + + return lock } -func (lm *lockManager) UnlockPolicy(name string, writeLock bool) { - lm.locksMutex.RLock() - lock := lm.locks[name] - lm.locksMutex.RUnlock() - - if writeLock { +func (lm *lockManager) UnlockPolicy(lock *sync.RWMutex, lockType bool) { + if lockType == exclusive { lock.Unlock() } else { lock.RUnlock() } } -func (lm *lockManager) GetPolicy(storage logical.Storage, name string) (*Policy, bool, error) { - p, lt, _, err := lm.getPolicyCommon(storage, name, false, false) - return p, lt, err +// Get the policy with a read lock. If we get an error saying an exclusive lock +// is needed (for instance, for an upgrade/migration), give up the read lock, +// call again with an exclusive lock, then swap back out for a read lock. 
+func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) { + p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, shared) + if err == nil || + (err != nil && err != errNeedExclusiveLock) { + return p, lock, err + } + + // Try again while asking for an exclusive lock + p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, exclusive) + if err != nil || p == nil || lock == nil { + return p, lock, err + } + + lock.Unlock() + + p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, shared) + return p, lock, err } -func (lm *lockManager) GetPolicyUpsert(storage logical.Storage, name string, derived bool) (*Policy, bool, bool, error) { - return lm.getPolicyCommon(storage, name, true, derived) +// Get the policy with an exclusive lock +func (lm *lockManager) GetPolicyExclusive(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) { + p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, exclusive) + return p, lock, err +} + +// Get the policy with a read lock; if it returns that an exclusive lock is +// needed, retry. If successful, call one more time to get a read lock and +// return the value. +func (lm *lockManager) GetPolicyUpsert(storage logical.Storage, name string, derived bool) (*Policy, *sync.RWMutex, bool, error) { + p, lock, upserted, err := lm.getPolicyCommon(storage, name, true, derived, shared) + if err == nil || + (err != nil && err != errNeedExclusiveLock) { + return p, lock, upserted, err + } + + // Try again while asking for an exclusive lock + p, lock, upserted, err = lm.getPolicyCommon(storage, name, true, derived, exclusive) + if err != nil || p == nil || lock == nil { + return p, lock, upserted, err + } + + lock.Unlock() + + return lm.getPolicyCommon(storage, name, true, derived, shared) } // When the function returns, a lock will be held on the policy if err == nil. -// The type of lock will be indicated by the return value. 
It is the caller's -// responsibility to unlock. -func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, upsert, derived bool) (p *Policy, lockType bool, upserted bool, err error) { - // If we are using a cache, lock it now to avoid having to do really - // complicated lock juggling as we call various functions. We'll also defer - // the store into the cache. - lockType = shared - lm.lockPolicy(name, shared) +// It is the caller's responsibility to unlock. +func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, upsert, derived, lockType bool) (*Policy, *sync.RWMutex, bool, error) { + lock := lm.policyLock(name, lockType) + var p *Policy + var err error + + // Check if it's in our cache. If so, return right away. if lm.CacheActive() { - lm.globalMutex.RLock() + lm.cacheMutex.RLock() p = lm.cache[name] if p != nil { - lm.globalMutex.RUnlock() - return + lm.cacheMutex.RUnlock() + return p, lock, false, nil } - lm.globalMutex.RUnlock() - - // When we return, since we didn't have the policy in the cache, if - // there was no error, write the value in. - defer func() { - lm.globalMutex.Lock() - defer lm.globalMutex.Unlock() - // Make sure a policy didn't appear. If so, it will only be set if - // there was no error, so now just clear the error and return that - // policy. 
- exp := lm.cache[name] - if exp != nil { - upserted = false - err = nil - p = exp - return - } - - if err == nil { - lm.cache[name] = p - } - }() + lm.cacheMutex.RUnlock() } + // Load it from storage p, err = lm.getStoredPolicy(storage, name) if err != nil { - lm.UnlockPolicy(name, shared) - return + lm.UnlockPolicy(lock, lockType) + return nil, nil, false, err } if p == nil { + // This is the only place we upsert a new policy, so if upsert is not + // specified, or the lock type is wrong, unlock before returning if !upsert { - lm.UnlockPolicy(name, shared) - return + lm.UnlockPolicy(lock, lockType) + return nil, nil, false, nil } - // Get an exlusive lock; on success, check again to ensure that no - // policy exists. Note that if we are using a cache we will already be - // serializing this entire code path and it's currently the only one - // that generates policies, so we don't need to check the cache here; - // simply checking the disk again is sufficient. - lm.UnlockPolicy(name, shared) - lockType = exclusive - lm.lockPolicy(name, exclusive) - - p, err = lm.getStoredPolicy(storage, name) - if err != nil { - defer lm.UnlockPolicy(name, exclusive) - return + if lockType != exclusive { + lm.UnlockPolicy(lock, lockType) + return nil, nil, false, errNeedExclusiveLock } - if p != nil { - return - } - - upserted = true p = &Policy{ Name: name, @@ -201,58 +200,75 @@ func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, ups err = p.rotate(storage) if err != nil { - defer lm.UnlockPolicy(name, exclusive) - p = nil + lm.UnlockPolicy(lock, lockType) + return nil, nil, false, err + } + + if lm.CacheActive() { + // Since we didn't have the policy in the cache, if there was no + // error, write the value in. + lm.cacheMutex.Lock() + defer lm.cacheMutex.Unlock() + // Make sure a policy didn't appear. 
If so, it will only be set if + // there was no error, so assume it's good and return that + exp := lm.cache[name] + if exp != nil { + return exp, lock, false, nil + } + if err == nil { + lm.cache[name] = p + } } // We don't need to worry about upgrading since it will be a new policy - return + return p, lock, true, nil } if p.needsUpgrade() { - lm.UnlockPolicy(name, shared) - lockType = exclusive - lm.lockPolicy(name, exclusive) - - // Reload the policy with the write lock to ensure we still need the upgrade - p, err = lm.getStoredPolicy(storage, name) - if err != nil { - lm.UnlockPolicy(name, exclusive) - return - } - if p == nil { - err = fmt.Errorf("error reloading policy for upgrade") - lm.UnlockPolicy(name, exclusive) - return - } - - if !p.needsUpgrade() { - // Already happened, return the newly loaded policy - return + if lockType == shared { + lm.UnlockPolicy(lock, lockType) + return nil, nil, false, errNeedExclusiveLock } err = p.upgrade(storage) if err != nil { - defer lm.UnlockPolicy(name, exclusive) + lm.UnlockPolicy(lock, lockType) + return nil, nil, false, err } } - return + if lm.CacheActive() { + // Since we didn't have the policy in the cache, if there was no + // error, write the value in. + lm.cacheMutex.Lock() + defer lm.cacheMutex.Unlock() + // Make sure a policy didn't appear. 
If so, it will only be set if + // there was no error, so assume it's good and return that + exp := lm.cache[name] + if exp != nil { + return exp, lock, false, nil + } + if err == nil { + lm.cache[name] = p + } + } + + return p, lock, false, nil } func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error { - lm.lockAll(name) - defer lm.UnlockAll(name) + lm.cacheMutex.Lock() + lock := lm.policyLock(name, exclusive) + defer lock.Unlock() + defer lm.cacheMutex.Unlock() var p *Policy var err error if lm.CacheActive() { p = lm.cache[name] - if p == nil { - return fmt.Errorf("could not delete policy; not found") - } - } else { + } + if p == nil { p, err = lm.getStoredPolicy(storage, name) if err != nil { return err @@ -283,42 +299,6 @@ func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error return nil } -// When this function returns it's the responsibility of the caller to call -// UnlockPolicy if err is nil and policy is not nil -func (lm *lockManager) RefreshPolicy(storage logical.Storage, name string) (p *Policy, err error) { - lm.lockPolicy(name, exclusive) - - if lm.CacheActive() { - p = lm.cache[name] - if p != nil { - return - } - err = fmt.Errorf("could not refresh policy; not found") - defer lm.UnlockPolicy(name, exclusive) - return - } - - p, err = lm.getStoredPolicy(storage, name) - if err != nil { - lm.UnlockPolicy(name, exclusive) - return - } - - if p == nil { - err = fmt.Errorf("could not refresh policy; not found") - defer lm.UnlockPolicy(name, exclusive) - } - - if p.needsUpgrade() { - err = p.upgrade(storage) - if err != nil { - defer lm.UnlockPolicy(name, exclusive) - } - } - - return -} - func (lm *lockManager) getStoredPolicy(storage logical.Storage, name string) (*Policy, error) { // Check if the policy already exists raw, err := storage.Get("policy/" + name) diff --git a/builtin/logical/transit/path_config.go b/builtin/logical/transit/path_config.go index eabb3603f6..5dc84c3f76 100644 --- 
a/builtin/logical/transit/path_config.go +++ b/builtin/logical/transit/path_config.go @@ -42,7 +42,10 @@ func (b *backend) pathConfigWrite( name := d.Get("name").(string) // Check if the policy already exists before we lock everything - p, lockType, err := b.lm.GetPolicy(req.Storage, name) + p, lock, err := b.lm.GetPolicyExclusive(req.Storage, name) + if lock != nil { + defer lock.Unlock() + } if err != nil { return nil, err } @@ -52,31 +55,8 @@ func (b *backend) pathConfigWrite( logical.ErrInvalidRequest } - // Store some values so we can detect a change when we lock everything - currDeletionAllowed := p.DeletionAllowed - currMinDecryptionVersion := p.MinDecryptionVersion - - b.lm.UnlockPolicy(name, lockType) - - // Refresh in case it's changed since before we grabbed the lock - p, err = b.lm.RefreshPolicy(req.Storage, name) - if err != nil { - return nil, err - } - if p == nil { - return nil, fmt.Errorf("error finding key %s after locking for changes", name) - } - defer b.lm.UnlockPolicy(name, exclusive) - resp := &logical.Response{} - // Check for anything to have been updated since we got the write lock - if currDeletionAllowed != p.DeletionAllowed || - currMinDecryptionVersion != p.MinDecryptionVersion { - resp.AddWarning("key configuration has changed since this endpoint was called, not updating") - return resp, nil - } - persistNeeded := false minDecryptionVersionRaw, ok := d.GetOk("min_decryption_version") diff --git a/builtin/logical/transit/path_datakey.go b/builtin/logical/transit/path_datakey.go index bd6fca6517..817529e9dd 100644 --- a/builtin/logical/transit/path_datakey.go +++ b/builtin/logical/transit/path_datakey.go @@ -73,14 +73,16 @@ func (b *backend) pathDatakeyWrite( } // Get the policy - p, lockType, err := b.lm.GetPolicy(req.Storage, name) + p, lock, err := b.lm.GetPolicyShared(req.Storage, name) + if lock != nil { + defer lock.RUnlock() + } if err != nil { return nil, err } if p == nil { return logical.ErrorResponse("policy not found"), 
logical.ErrInvalidRequest } - defer b.lm.UnlockPolicy(name, lockType) newKey := make([]byte, 32) bits := d.Get("bits").(int) diff --git a/builtin/logical/transit/path_decrypt.go b/builtin/logical/transit/path_decrypt.go index 8c5799331f..254b2b5093 100644 --- a/builtin/logical/transit/path_decrypt.go +++ b/builtin/logical/transit/path_decrypt.go @@ -58,7 +58,10 @@ func (b *backend) pathDecryptWrite( } // Get the policy - p, lockType, err := b.lm.GetPolicy(req.Storage, name) + p, lock, err := b.lm.GetPolicyShared(req.Storage, name) + if lock != nil { + defer lock.RUnlock() + } if err != nil { return nil, err } @@ -66,8 +69,6 @@ func (b *backend) pathDecryptWrite( return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } - defer b.lm.UnlockPolicy(name, lockType) - plaintext, err := p.Decrypt(context, ciphertext) if err != nil { switch err.(type) { diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go index bba7667972..3494faac81 100644 --- a/builtin/logical/transit/path_encrypt.go +++ b/builtin/logical/transit/path_encrypt.go @@ -3,6 +3,7 @@ package transit import ( "encoding/base64" "fmt" + "sync" "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/logical" @@ -44,13 +45,13 @@ func (b *backend) pathEncrypt() *framework.Path { func (b *backend) pathEncryptExistenceCheck( req *logical.Request, d *framework.FieldData) (bool, error) { name := d.Get("name").(string) - p, lockType, err := b.lm.GetPolicy(req.Storage, name) + p, lock, err := b.lm.GetPolicyShared(req.Storage, name) + if lock != nil { + defer lock.RUnlock() + } if err != nil { return false, err } - if p != nil { - defer b.lm.UnlockPolicy(name, lockType) - } return p != nil, nil } @@ -75,12 +76,15 @@ func (b *backend) pathEncryptWrite( // Get the policy var p *Policy - var lockType bool + var lock *sync.RWMutex var upserted bool if req.Operation == logical.CreateOperation { - p, lockType, upserted, err = 
b.lm.GetPolicyUpsert(req.Storage, name, len(context) != 0) + p, lock, upserted, err = b.lm.GetPolicyUpsert(req.Storage, name, len(context) != 0) } else { - p, lockType, err = b.lm.GetPolicy(req.Storage, name) + p, lock, err = b.lm.GetPolicyShared(req.Storage, name) + } + if lock != nil { + defer lock.RUnlock() } if err != nil { return nil, err @@ -88,7 +92,6 @@ func (b *backend) pathEncryptWrite( if p == nil { return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } - defer b.lm.UnlockPolicy(name, lockType) ciphertext, err := p.Encrypt(context, value) if err != nil { diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index 956b295d15..466ca9b508 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -39,7 +39,10 @@ func (b *backend) pathPolicyWrite( name := d.Get("name").(string) derived := d.Get("derived").(bool) - p, lockType, upserted, err := b.lm.GetPolicyUpsert(req.Storage, name, derived) + p, lock, upserted, err := b.lm.GetPolicyUpsert(req.Storage, name, derived) + if lock != nil { + defer lock.RUnlock() + } if err != nil { return nil, err } @@ -47,8 +50,6 @@ func (b *backend) pathPolicyWrite( return nil, fmt.Errorf("error generating key: returned policy was nil") } - defer b.lm.UnlockPolicy(name, lockType) - resp := &logical.Response{} if !upserted { resp.AddWarning(fmt.Sprintf("key %s already existed", name)) @@ -61,7 +62,10 @@ func (b *backend) pathPolicyRead( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - p, lockType, err := b.lm.GetPolicy(req.Storage, name) + p, lock, err := b.lm.GetPolicyShared(req.Storage, name) + if lock != nil { + defer lock.RUnlock() + } if err != nil { return nil, err } @@ -69,8 +73,6 @@ func (b *backend) pathPolicyRead( return nil, nil } - defer b.lm.UnlockPolicy(name, lockType) - // Return the response resp := &logical.Response{ Data: map[string]interface{}{ @@ -99,17 
+101,8 @@ func (b *backend) pathPolicyDelete( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - // Some sanity checking before we lock it all in the DeletePolicy method - p, lockType, err := b.lm.GetPolicy(req.Storage, name) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf("error looking up policy %s, error is %s", name, err)), err - } - if p == nil { - return logical.ErrorResponse(fmt.Sprintf("no such key %s", name)), logical.ErrInvalidRequest - } - b.lm.UnlockPolicy(name, lockType) - - err = b.lm.DeletePolicy(req.Storage, name) + // Delete does its own locking + err := b.lm.DeletePolicy(req.Storage, name) if err != nil { return logical.ErrorResponse(fmt.Sprintf("error deleting policy %s: %s", name, err)), err } diff --git a/builtin/logical/transit/path_rewrap.go b/builtin/logical/transit/path_rewrap.go index 61e3f607e8..a5854feeea 100644 --- a/builtin/logical/transit/path_rewrap.go +++ b/builtin/logical/transit/path_rewrap.go @@ -59,7 +59,10 @@ func (b *backend) pathRewrapWrite( } // Get the policy - p, lockType, err := b.lm.GetPolicy(req.Storage, name) + p, lock, err := b.lm.GetPolicyShared(req.Storage, name) + if lock != nil { + defer lock.RUnlock() + } if err != nil { return nil, err } @@ -68,8 +71,6 @@ func (b *backend) pathRewrapWrite( return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } - defer b.lm.UnlockPolicy(name, lockType) - plaintext, err := p.Decrypt(context, value) if err != nil { switch err.(type) { diff --git a/builtin/logical/transit/path_rotate.go b/builtin/logical/transit/path_rotate.go index 442646ef8c..f10b78d56a 100644 --- a/builtin/logical/transit/path_rotate.go +++ b/builtin/logical/transit/path_rotate.go @@ -1,8 +1,6 @@ package transit import ( - "fmt" - "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -31,7 +29,10 @@ func (b *backend) pathRotateWrite( name := d.Get("name").(string) // Get the policy - p, 
lockType, err := b.lm.GetPolicy(req.Storage, name) + p, lock, err := b.lm.GetPolicyExclusive(req.Storage, name) + if lock != nil { + defer lock.Unlock() + } if err != nil { return nil, err } @@ -39,28 +40,6 @@ func (b *backend) pathRotateWrite( return logical.ErrorResponse("key not found"), logical.ErrInvalidRequest } - // Store so we can detect later if this has changed out from under us - keyVersion := p.LatestVersion - - b.lm.UnlockPolicy(name, lockType) - - // Refresh in case it's changed since before we grabbed the lock - p, err = b.lm.RefreshPolicy(req.Storage, name) - if err != nil { - return nil, err - } - if p == nil { - return nil, fmt.Errorf("error finding key %s after locking for changes", name) - } - defer b.lm.UnlockPolicy(name, exclusive) - - // Make sure that the policy hasn't been rotated simultaneously - if keyVersion != p.LatestVersion { - resp := &logical.Response{} - resp.AddWarning("key has been rotated since this endpoint was called; did not perform rotation") - return resp, nil - } - // Rotate the policy err = p.rotate(req.Storage)