Massively simplify lock handling based on feedback

This commit is contained in:
Jeff Mitchell 2016-05-02 23:46:39 -04:00
parent bf7ad912e1
commit 027d570f7f
8 changed files with 170 additions and 231 deletions

View file

@ -2,6 +2,7 @@ package transit
import (
"encoding/json"
"errors"
"fmt"
"sync"
@ -13,6 +14,10 @@ const (
exclusive = true
)
var (
errNeedExclusiveLock = errors.New("an exclusive lock is needed for this operation")
)
type lockManager struct {
// A lock for each named key
locks map[string]*sync.RWMutex
@ -24,7 +29,7 @@ type lockManager struct {
cache map[string]*Policy
// Used for global locking, and as the cache map mutex
globalMutex sync.RWMutex
cacheMutex sync.RWMutex
}
func newLockManager(cacheDisabled bool) *lockManager {
@ -41,17 +46,7 @@ func (lm *lockManager) CacheActive() bool {
return lm.cache != nil
}
func (lm *lockManager) lockAll(name string) {
lm.globalMutex.Lock()
lm.lockPolicy(name, exclusive)
}
func (lm *lockManager) UnlockAll(name string) {
lm.UnlockPolicy(name, exclusive)
lm.globalMutex.Unlock()
}
func (lm *lockManager) lockPolicy(name string, writeLock bool) {
func (lm *lockManager) policyLock(name string, lockType bool) *sync.RWMutex {
lm.locksMutex.RLock()
lock := lm.locks[name]
if lock != nil {
@ -59,12 +54,12 @@ func (lm *lockManager) lockPolicy(name string, writeLock bool) {
// the only time we ever write to a value in this map is the first time
// we access the value, so it won't be changing out from under us
lm.locksMutex.RUnlock()
if writeLock {
if lockType == exclusive {
lock.Lock()
} else {
lock.RLock()
}
return
return lock
}
lm.locksMutex.RUnlock()
@ -78,117 +73,121 @@ func (lm *lockManager) lockPolicy(name string, writeLock bool) {
lock = lm.locks[name]
if lock != nil {
lm.locksMutex.Unlock()
if writeLock {
if lockType == exclusive {
lock.Lock()
} else {
lock.RLock()
}
return
return lock
}
lock = &sync.RWMutex{}
lm.locks[name] = lock
lm.locksMutex.Unlock()
if writeLock {
if lockType == exclusive {
lock.Lock()
} else {
lock.RLock()
}
return lock
}
func (lm *lockManager) UnlockPolicy(name string, writeLock bool) {
lm.locksMutex.RLock()
lock := lm.locks[name]
lm.locksMutex.RUnlock()
if writeLock {
func (lm *lockManager) UnlockPolicy(lock *sync.RWMutex, lockType bool) {
if lockType == exclusive {
lock.Unlock()
} else {
lock.RUnlock()
}
}
func (lm *lockManager) GetPolicy(storage logical.Storage, name string) (*Policy, bool, error) {
p, lt, _, err := lm.getPolicyCommon(storage, name, false, false)
return p, lt, err
// Get the policy with a read lock. If we get an error saying an exclusive lock
// is needed (for instance, for an upgrade/migration), give up the read lock,
// call again with an exclusive lock, then swap back out for a read lock.
func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) {
p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, shared)
if err == nil ||
(err != nil && err != errNeedExclusiveLock) {
return p, lock, err
}
// Try again while asking for an exlusive lock
p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, exclusive)
if err != nil || p == nil || lock == nil {
return p, lock, err
}
lock.Unlock()
p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, shared)
return p, lock, err
}
func (lm *lockManager) GetPolicyUpsert(storage logical.Storage, name string, derived bool) (*Policy, bool, bool, error) {
return lm.getPolicyCommon(storage, name, true, derived)
// Get the policy with an exclusive lock
func (lm *lockManager) GetPolicyExclusive(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) {
p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, exclusive)
return p, lock, err
}
// Get the policy with a read lock; if it returns that an exclusive lock is
// needed, retry. If successful, call one more time to get a read lock and
// return the value.
func (lm *lockManager) GetPolicyUpsert(storage logical.Storage, name string, derived bool) (*Policy, *sync.RWMutex, bool, error) {
p, lock, upserted, err := lm.getPolicyCommon(storage, name, true, derived, shared)
if err == nil ||
(err != nil && err != errNeedExclusiveLock) {
return p, lock, upserted, err
}
// Try again while asking for an exlusive lock
p, lock, upserted, err = lm.getPolicyCommon(storage, name, true, derived, exclusive)
if err != nil || p == nil || lock == nil {
return p, lock, upserted, err
}
lock.Unlock()
return lm.getPolicyCommon(storage, name, true, derived, shared)
}
// When the function returns, a lock will be held on the policy if err == nil.
// The type of lock will be indicated by the return value. It is the caller's
// responsibility to unlock.
func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, upsert, derived bool) (p *Policy, lockType bool, upserted bool, err error) {
// If we are using a cache, lock it now to avoid having to do really
// complicated lock juggling as we call various functions. We'll also defer
// the store into the cache.
lockType = shared
lm.lockPolicy(name, shared)
// It is the caller's responsibility to unlock.
func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, upsert, derived, lockType bool) (*Policy, *sync.RWMutex, bool, error) {
lock := lm.policyLock(name, lockType)
var p *Policy
var err error
// Check if it's in our cache. If so, return right away.
if lm.CacheActive() {
lm.globalMutex.RLock()
lm.cacheMutex.RLock()
p = lm.cache[name]
if p != nil {
lm.globalMutex.RUnlock()
return
lm.cacheMutex.RUnlock()
return p, lock, false, nil
}
lm.globalMutex.RUnlock()
// When we return, since we didn't have the policy in the cache, if
// there was no error, write the value in.
defer func() {
lm.globalMutex.Lock()
defer lm.globalMutex.Unlock()
// Make sure a policy didn't appear. If so, it will only be set if
// there was no error, so now just clear the error and return that
// policy.
exp := lm.cache[name]
if exp != nil {
upserted = false
err = nil
p = exp
return
}
if err == nil {
lm.cache[name] = p
}
}()
lm.cacheMutex.RUnlock()
}
// Load it from storage
p, err = lm.getStoredPolicy(storage, name)
if err != nil {
lm.UnlockPolicy(name, shared)
return
lm.UnlockPolicy(lock, lockType)
return nil, nil, false, err
}
if p == nil {
// This is the only place we upsert a new policy, so if upsert is not
// specified, or the lock type is wrong, unllock before returning
if !upsert {
lm.UnlockPolicy(name, shared)
return
lm.UnlockPolicy(lock, lockType)
return nil, nil, false, nil
}
// Get an exlusive lock; on success, check again to ensure that no
// policy exists. Note that if we are using a cache we will already be
// serializing this entire code path and it's currently the only one
// that generates policies, so we don't need to check the cache here;
// simply checking the disk again is sufficient.
lm.UnlockPolicy(name, shared)
lockType = exclusive
lm.lockPolicy(name, exclusive)
p, err = lm.getStoredPolicy(storage, name)
if err != nil {
defer lm.UnlockPolicy(name, exclusive)
return
if lockType != exclusive {
lm.UnlockPolicy(lock, lockType)
return nil, nil, false, errNeedExclusiveLock
}
if p != nil {
return
}
upserted = true
p = &Policy{
Name: name,
@ -201,58 +200,75 @@ func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, ups
err = p.rotate(storage)
if err != nil {
defer lm.UnlockPolicy(name, exclusive)
p = nil
lm.UnlockPolicy(lock, lockType)
return nil, nil, false, err
}
if lm.CacheActive() {
// Since we didn't have the policy in the cache, if there was no
// error, write the value in.
lm.cacheMutex.Lock()
defer lm.cacheMutex.Unlock()
// Make sure a policy didn't appear. If so, it will only be set if
// there was no error, so assume it's good and return that
exp := lm.cache[name]
if exp != nil {
return exp, lock, false, nil
}
if err == nil {
lm.cache[name] = p
}
}
// We don't need to worry about upgrading since it will be a new policy
return
return p, lock, true, nil
}
if p.needsUpgrade() {
lm.UnlockPolicy(name, shared)
lockType = exclusive
lm.lockPolicy(name, exclusive)
// Reload the policy with the write lock to ensure we still need the upgrade
p, err = lm.getStoredPolicy(storage, name)
if err != nil {
lm.UnlockPolicy(name, exclusive)
return
}
if p == nil {
err = fmt.Errorf("error reloading policy for upgrade")
lm.UnlockPolicy(name, exclusive)
return
}
if !p.needsUpgrade() {
// Already happened, return the newly loaded policy
return
if lockType == shared {
lm.UnlockPolicy(lock, lockType)
return nil, nil, false, errNeedExclusiveLock
}
err = p.upgrade(storage)
if err != nil {
defer lm.UnlockPolicy(name, exclusive)
lm.UnlockPolicy(lock, lockType)
return nil, nil, false, err
}
}
return
if lm.CacheActive() {
// Since we didn't have the policy in the cache, if there was no
// error, write the value in.
lm.cacheMutex.Lock()
defer lm.cacheMutex.Unlock()
// Make sure a policy didn't appear. If so, it will only be set if
// there was no error, so assume it's good and return that
exp := lm.cache[name]
if exp != nil {
return exp, lock, false, nil
}
if err == nil {
lm.cache[name] = p
}
}
return p, lock, false, nil
}
func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error {
lm.lockAll(name)
defer lm.UnlockAll(name)
lm.cacheMutex.Lock()
lock := lm.policyLock(name, exclusive)
defer lock.Unlock()
defer lm.cacheMutex.Unlock()
var p *Policy
var err error
if lm.CacheActive() {
p = lm.cache[name]
if p == nil {
return fmt.Errorf("could not delete policy; not found")
}
} else {
}
if p == nil {
p, err = lm.getStoredPolicy(storage, name)
if err != nil {
return err
@ -283,42 +299,6 @@ func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error
return nil
}
// When this function returns it's the responsibility of the caller to call
// UnlockPolicy if err is nil and policy is not nil
func (lm *lockManager) RefreshPolicy(storage logical.Storage, name string) (p *Policy, err error) {
lm.lockPolicy(name, exclusive)
if lm.CacheActive() {
p = lm.cache[name]
if p != nil {
return
}
err = fmt.Errorf("could not refresh policy; not found")
defer lm.UnlockPolicy(name, exclusive)
return
}
p, err = lm.getStoredPolicy(storage, name)
if err != nil {
lm.UnlockPolicy(name, exclusive)
return
}
if p == nil {
err = fmt.Errorf("could not refresh policy; not found")
defer lm.UnlockPolicy(name, exclusive)
}
if p.needsUpgrade() {
err = p.upgrade(storage)
if err != nil {
defer lm.UnlockPolicy(name, exclusive)
}
}
return
}
func (lm *lockManager) getStoredPolicy(storage logical.Storage, name string) (*Policy, error) {
// Check if the policy already exists
raw, err := storage.Get("policy/" + name)

View file

@ -42,7 +42,10 @@ func (b *backend) pathConfigWrite(
name := d.Get("name").(string)
// Check if the policy already exists before we lock everything
p, lockType, err := b.lm.GetPolicy(req.Storage, name)
p, lock, err := b.lm.GetPolicyExclusive(req.Storage, name)
if lock != nil {
defer lock.Unlock()
}
if err != nil {
return nil, err
}
@ -52,31 +55,8 @@ func (b *backend) pathConfigWrite(
logical.ErrInvalidRequest
}
// Store some values so we can detect a change when we lock everything
currDeletionAllowed := p.DeletionAllowed
currMinDecryptionVersion := p.MinDecryptionVersion
b.lm.UnlockPolicy(name, lockType)
// Refresh in case it's changed since before we grabbed the lock
p, err = b.lm.RefreshPolicy(req.Storage, name)
if err != nil {
return nil, err
}
if p == nil {
return nil, fmt.Errorf("error finding key %s after locking for changes", name)
}
defer b.lm.UnlockPolicy(name, exclusive)
resp := &logical.Response{}
// Check for anything to have been updated since we got the write lock
if currDeletionAllowed != p.DeletionAllowed ||
currMinDecryptionVersion != p.MinDecryptionVersion {
resp.AddWarning("key configuration has changed since this endpoint was called, not updating")
return resp, nil
}
persistNeeded := false
minDecryptionVersionRaw, ok := d.GetOk("min_decryption_version")

View file

@ -73,14 +73,16 @@ func (b *backend) pathDatakeyWrite(
}
// Get the policy
p, lockType, err := b.lm.GetPolicy(req.Storage, name)
p, lock, err := b.lm.GetPolicyShared(req.Storage, name)
if lock != nil {
defer lock.RUnlock()
}
if err != nil {
return nil, err
}
if p == nil {
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
}
defer b.lm.UnlockPolicy(name, lockType)
newKey := make([]byte, 32)
bits := d.Get("bits").(int)

View file

@ -58,7 +58,10 @@ func (b *backend) pathDecryptWrite(
}
// Get the policy
p, lockType, err := b.lm.GetPolicy(req.Storage, name)
p, lock, err := b.lm.GetPolicyShared(req.Storage, name)
if lock != nil {
defer lock.RUnlock()
}
if err != nil {
return nil, err
}
@ -66,8 +69,6 @@ func (b *backend) pathDecryptWrite(
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
}
defer b.lm.UnlockPolicy(name, lockType)
plaintext, err := p.Decrypt(context, ciphertext)
if err != nil {
switch err.(type) {

View file

@ -3,6 +3,7 @@ package transit
import (
"encoding/base64"
"fmt"
"sync"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/logical"
@ -44,13 +45,13 @@ func (b *backend) pathEncrypt() *framework.Path {
func (b *backend) pathEncryptExistenceCheck(
req *logical.Request, d *framework.FieldData) (bool, error) {
name := d.Get("name").(string)
p, lockType, err := b.lm.GetPolicy(req.Storage, name)
p, lock, err := b.lm.GetPolicyShared(req.Storage, name)
if lock != nil {
defer lock.RUnlock()
}
if err != nil {
return false, err
}
if p != nil {
defer b.lm.UnlockPolicy(name, lockType)
}
return p != nil, nil
}
@ -75,12 +76,15 @@ func (b *backend) pathEncryptWrite(
// Get the policy
var p *Policy
var lockType bool
var lock *sync.RWMutex
var upserted bool
if req.Operation == logical.CreateOperation {
p, lockType, upserted, err = b.lm.GetPolicyUpsert(req.Storage, name, len(context) != 0)
p, lock, upserted, err = b.lm.GetPolicyUpsert(req.Storage, name, len(context) != 0)
} else {
p, lockType, err = b.lm.GetPolicy(req.Storage, name)
p, lock, err = b.lm.GetPolicyShared(req.Storage, name)
}
if lock != nil {
defer lock.RUnlock()
}
if err != nil {
return nil, err
@ -88,7 +92,6 @@ func (b *backend) pathEncryptWrite(
if p == nil {
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
}
defer b.lm.UnlockPolicy(name, lockType)
ciphertext, err := p.Encrypt(context, value)
if err != nil {

View file

@ -39,7 +39,10 @@ func (b *backend) pathPolicyWrite(
name := d.Get("name").(string)
derived := d.Get("derived").(bool)
p, lockType, upserted, err := b.lm.GetPolicyUpsert(req.Storage, name, derived)
p, lock, upserted, err := b.lm.GetPolicyUpsert(req.Storage, name, derived)
if lock != nil {
defer lock.RUnlock()
}
if err != nil {
return nil, err
}
@ -47,8 +50,6 @@ func (b *backend) pathPolicyWrite(
return nil, fmt.Errorf("error generating key: returned policy was nil")
}
defer b.lm.UnlockPolicy(name, lockType)
resp := &logical.Response{}
if !upserted {
resp.AddWarning(fmt.Sprintf("key %s already existed", name))
@ -61,7 +62,10 @@ func (b *backend) pathPolicyRead(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
p, lockType, err := b.lm.GetPolicy(req.Storage, name)
p, lock, err := b.lm.GetPolicyShared(req.Storage, name)
if lock != nil {
defer lock.RUnlock()
}
if err != nil {
return nil, err
}
@ -69,8 +73,6 @@ func (b *backend) pathPolicyRead(
return nil, nil
}
defer b.lm.UnlockPolicy(name, lockType)
// Return the response
resp := &logical.Response{
Data: map[string]interface{}{
@ -99,17 +101,8 @@ func (b *backend) pathPolicyDelete(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
// Some sanity checking before we lock it all in the DeletePolicy method
p, lockType, err := b.lm.GetPolicy(req.Storage, name)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("error looking up policy %s, error is %s", name, err)), err
}
if p == nil {
return logical.ErrorResponse(fmt.Sprintf("no such key %s", name)), logical.ErrInvalidRequest
}
b.lm.UnlockPolicy(name, lockType)
err = b.lm.DeletePolicy(req.Storage, name)
// Delete does its own locking
err := b.lm.DeletePolicy(req.Storage, name)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("error deleting policy %s: %s", name, err)), err
}

View file

@ -59,7 +59,10 @@ func (b *backend) pathRewrapWrite(
}
// Get the policy
p, lockType, err := b.lm.GetPolicy(req.Storage, name)
p, lock, err := b.lm.GetPolicyShared(req.Storage, name)
if lock != nil {
defer lock.RUnlock()
}
if err != nil {
return nil, err
}
@ -68,8 +71,6 @@ func (b *backend) pathRewrapWrite(
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
}
defer b.lm.UnlockPolicy(name, lockType)
plaintext, err := p.Decrypt(context, value)
if err != nil {
switch err.(type) {

View file

@ -1,8 +1,6 @@
package transit
import (
"fmt"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -31,7 +29,10 @@ func (b *backend) pathRotateWrite(
name := d.Get("name").(string)
// Get the policy
p, lockType, err := b.lm.GetPolicy(req.Storage, name)
p, lock, err := b.lm.GetPolicyExclusive(req.Storage, name)
if lock != nil {
defer lock.Unlock()
}
if err != nil {
return nil, err
}
@ -39,28 +40,6 @@ func (b *backend) pathRotateWrite(
return logical.ErrorResponse("key not found"), logical.ErrInvalidRequest
}
// Store so we can detect later if this has changed out from under us
keyVersion := p.LatestVersion
b.lm.UnlockPolicy(name, lockType)
// Refresh in case it's changed since before we grabbed the lock
p, err = b.lm.RefreshPolicy(req.Storage, name)
if err != nil {
return nil, err
}
if p == nil {
return nil, fmt.Errorf("error finding key %s after locking for changes", name)
}
defer b.lm.UnlockPolicy(name, exclusive)
// Make sure that the policy hasn't been rotated simultaneously
if keyVersion != p.LatestVersion {
resp := &logical.Response{}
resp.AddWarning("key has been rotated since this endpoint was called; did not perform rotation")
return resp, nil
}
// Rotate the policy
err = p.rotate(req.Storage)