mirror of
https://github.com/hashicorp/vault.git
synced 2026-05-28 04:10:44 -04:00
Merge pull request #1346 from hashicorp/disable-all-caches
Disable all caches
This commit is contained in:
commit
3ca09fdf30
18 changed files with 715 additions and 439 deletions
|
|
@ -6,7 +6,7 @@ import (
|
|||
)
|
||||
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend()
|
||||
b := Backend(conf)
|
||||
be, err := b.Backend.Setup(conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -15,7 +15,7 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
|||
return be, nil
|
||||
}
|
||||
|
||||
func Backend() *backend {
|
||||
func Backend(conf *logical.BackendConfig) *backend {
|
||||
var b backend
|
||||
b.Backend = &framework.Backend{
|
||||
Paths: []*framework.Path{
|
||||
|
|
@ -33,14 +33,12 @@ func Backend() *backend {
|
|||
Secrets: []*framework.Secret{},
|
||||
}
|
||||
|
||||
b.policies = policyCache{
|
||||
cache: map[string]*lockingPolicy{},
|
||||
}
|
||||
b.lm = newLockManager(conf.System.CachingDisabled())
|
||||
|
||||
return &b
|
||||
}
|
||||
|
||||
type backend struct {
|
||||
*framework.Backend
|
||||
policies policyCache
|
||||
lm *lockManager
|
||||
}
|
||||
|
|
|
|||
|
|
@ -559,16 +559,31 @@ func TestPolicyFuzzing(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
be := Backend()
|
||||
var be *backend
|
||||
sysView := logical.TestSystemView()
|
||||
|
||||
be = Backend(&logical.BackendConfig{
|
||||
System: sysView,
|
||||
})
|
||||
testPolicyFuzzingCommon(t, be)
|
||||
|
||||
sysView.CachingDisabledVal = true
|
||||
be = Backend(&logical.BackendConfig{
|
||||
System: sysView,
|
||||
})
|
||||
testPolicyFuzzingCommon(t, be)
|
||||
}
|
||||
|
||||
func testPolicyFuzzingCommon(t *testing.T, be *backend) {
|
||||
storage := &logical.LockingInmemStorage{}
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
funcs := []string{"encrypt", "decrypt", "rotate", "change_min_version"}
|
||||
//keys := []string{"test1", "test2", "test3", "test4", "test5"}
|
||||
keys := []string{"test1", "test2", "test3"}
|
||||
|
||||
// This is the goroutine loop
|
||||
doFuzzy := func() {
|
||||
doFuzzy := func(id int) {
|
||||
// Check for panics, otherwise notify we're done
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
|
|
@ -587,6 +602,9 @@ func TestPolicyFuzzing(t *testing.T) {
|
|||
}
|
||||
fd := &framework.FieldData{}
|
||||
|
||||
var chosenFunc, chosenKey string
|
||||
|
||||
//t.Errorf("Starting %d", id)
|
||||
for {
|
||||
// Stop after 10 seconds
|
||||
if time.Now().Sub(startTime) > 10*time.Second {
|
||||
|
|
@ -594,8 +612,8 @@ func TestPolicyFuzzing(t *testing.T) {
|
|||
}
|
||||
|
||||
// Pick a function and a key
|
||||
chosenFunc := funcs[rand.Int()%len(funcs)]
|
||||
chosenKey := keys[rand.Int()%len(keys)]
|
||||
chosenFunc = funcs[rand.Int()%len(funcs)]
|
||||
chosenKey = keys[rand.Int()%len(keys)]
|
||||
|
||||
fd.Raw = map[string]interface{}{
|
||||
"name": chosenKey,
|
||||
|
|
@ -605,33 +623,33 @@ func TestPolicyFuzzing(t *testing.T) {
|
|||
// Try to write the key to make sure it exists
|
||||
_, err := be.pathPolicyWrite(req, fd)
|
||||
if err != nil {
|
||||
t.Errorf("got an error: %v", err)
|
||||
return
|
||||
t.Fatalf("got an error: %v", err)
|
||||
}
|
||||
|
||||
switch chosenFunc {
|
||||
// Encrypt our plaintext and store the result
|
||||
case "encrypt":
|
||||
//t.Errorf("%s, %s, %d", chosenFunc, chosenKey, id)
|
||||
fd.Raw["plaintext"] = base64.StdEncoding.EncodeToString([]byte(testPlaintext))
|
||||
fd.Schema = be.pathEncrypt().Fields
|
||||
resp, err := be.pathEncryptWrite(req, fd)
|
||||
if err != nil {
|
||||
t.Errorf("got an error: %v, resp is %#v", err, *resp)
|
||||
return
|
||||
t.Fatalf("got an error: %v, resp is %#v", err, *resp)
|
||||
}
|
||||
latestEncryptedText[chosenKey] = resp.Data["ciphertext"].(string)
|
||||
|
||||
// Rotate to a new key version
|
||||
case "rotate":
|
||||
//t.Errorf("%s, %s, %d", chosenFunc, chosenKey, id)
|
||||
fd.Schema = be.pathRotate().Fields
|
||||
resp, err := be.pathRotateWrite(req, fd)
|
||||
if err != nil {
|
||||
t.Errorf("got an error: %v, resp is %#v", err, *resp)
|
||||
return
|
||||
t.Fatalf("got an error: %v, resp is %#v, chosenKey is %s", err, *resp, chosenKey)
|
||||
}
|
||||
|
||||
// Decrypt the ciphertext and compare the result
|
||||
case "decrypt":
|
||||
//t.Errorf("%s, %s, %d", chosenFunc, chosenKey, id)
|
||||
ct := latestEncryptedText[chosenKey]
|
||||
if ct == "" {
|
||||
continue
|
||||
|
|
@ -645,13 +663,12 @@ func TestPolicyFuzzing(t *testing.T) {
|
|||
if resp.Data["error"].(string) == ErrTooOld {
|
||||
continue
|
||||
}
|
||||
t.Errorf("got an error: %v, resp is %#v, ciphertext was %s", err, *resp, latestEncryptedText[chosenKey])
|
||||
return
|
||||
t.Fatalf("got an error: %v, resp is %#v, ciphertext was %s, chosenKey is %s, id is %d", err, *resp, ct, chosenKey, id)
|
||||
}
|
||||
ptb64 := resp.Data["plaintext"].(string)
|
||||
pt, err := base64.StdEncoding.DecodeString(ptb64)
|
||||
if err != nil {
|
||||
t.Errorf("got an error decoding base64 plaintext: %v", err)
|
||||
t.Fatalf("got an error decoding base64 plaintext: %v", err)
|
||||
return
|
||||
}
|
||||
if string(pt) != testPlaintext {
|
||||
|
|
@ -660,10 +677,10 @@ func TestPolicyFuzzing(t *testing.T) {
|
|||
|
||||
// Change the min version, which also tests the archive functionality
|
||||
case "change_min_version":
|
||||
//t.Errorf("%s, %s, %d", chosenFunc, chosenKey, id)
|
||||
resp, err := be.pathPolicyRead(req, fd)
|
||||
if err != nil {
|
||||
t.Errorf("got an error reading policy %s: %v", chosenKey, err)
|
||||
return
|
||||
t.Fatalf("got an error reading policy %s: %v", chosenKey, err)
|
||||
}
|
||||
latestVersion := resp.Data["latest_version"].(int)
|
||||
|
||||
|
|
@ -673,8 +690,7 @@ func TestPolicyFuzzing(t *testing.T) {
|
|||
fd.Schema = be.pathConfig().Fields
|
||||
resp, err = be.pathConfigWrite(req, fd)
|
||||
if err != nil {
|
||||
t.Errorf("got an error setting min decryption version: %v", err)
|
||||
return
|
||||
t.Fatalf("got an error setting min decryption version: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -683,7 +699,7 @@ func TestPolicyFuzzing(t *testing.T) {
|
|||
// Spawn 1000 of these workers for 10 seconds
|
||||
for i := 0; i < 1000; i++ {
|
||||
wg.Add(1)
|
||||
go doFuzzy()
|
||||
go doFuzzy(i)
|
||||
}
|
||||
|
||||
// Wait for them all to finish
|
||||
|
|
|
|||
325
builtin/logical/transit/lock_manager.go
Normal file
325
builtin/logical/transit/lock_manager.go
Normal file
|
|
@ -0,0 +1,325 @@
|
|||
package transit
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
const (
|
||||
shared = false
|
||||
exclusive = true
|
||||
)
|
||||
|
||||
var (
|
||||
errNeedExclusiveLock = errors.New("an exclusive lock is needed for this operation")
|
||||
)
|
||||
|
||||
type lockManager struct {
|
||||
// A lock for each named key
|
||||
locks map[string]*sync.RWMutex
|
||||
|
||||
// A mutex for the map itself
|
||||
locksMutex sync.RWMutex
|
||||
|
||||
// If caching is enabled, the map of name to in-memory policy cache
|
||||
cache map[string]*Policy
|
||||
|
||||
// Used for global locking, and as the cache map mutex
|
||||
cacheMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func newLockManager(cacheDisabled bool) *lockManager {
|
||||
lm := &lockManager{
|
||||
locks: map[string]*sync.RWMutex{},
|
||||
}
|
||||
if !cacheDisabled {
|
||||
lm.cache = map[string]*Policy{}
|
||||
}
|
||||
return lm
|
||||
}
|
||||
|
||||
func (lm *lockManager) CacheActive() bool {
|
||||
return lm.cache != nil
|
||||
}
|
||||
|
||||
func (lm *lockManager) policyLock(name string, lockType bool) *sync.RWMutex {
|
||||
lm.locksMutex.RLock()
|
||||
lock := lm.locks[name]
|
||||
if lock != nil {
|
||||
// We want to give this up before locking the lock, but it's safe --
|
||||
// the only time we ever write to a value in this map is the first time
|
||||
// we access the value, so it won't be changing out from under us
|
||||
lm.locksMutex.RUnlock()
|
||||
if lockType == exclusive {
|
||||
lock.Lock()
|
||||
} else {
|
||||
lock.RLock()
|
||||
}
|
||||
return lock
|
||||
}
|
||||
|
||||
lm.locksMutex.RUnlock()
|
||||
lm.locksMutex.Lock()
|
||||
|
||||
// Don't defer the unlock call because if we get a valid lock below we want
|
||||
// to release the lock mutex right away to avoid the possibility of
|
||||
// deadlock by trying to grab the second lock
|
||||
|
||||
// Check to make sure it hasn't been created since
|
||||
lock = lm.locks[name]
|
||||
if lock != nil {
|
||||
lm.locksMutex.Unlock()
|
||||
if lockType == exclusive {
|
||||
lock.Lock()
|
||||
} else {
|
||||
lock.RLock()
|
||||
}
|
||||
return lock
|
||||
}
|
||||
|
||||
lock = &sync.RWMutex{}
|
||||
lm.locks[name] = lock
|
||||
lm.locksMutex.Unlock()
|
||||
if lockType == exclusive {
|
||||
lock.Lock()
|
||||
} else {
|
||||
lock.RLock()
|
||||
}
|
||||
|
||||
return lock
|
||||
}
|
||||
|
||||
func (lm *lockManager) UnlockPolicy(lock *sync.RWMutex, lockType bool) {
|
||||
if lockType == exclusive {
|
||||
lock.Unlock()
|
||||
} else {
|
||||
lock.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Get the policy with a read lock. If we get an error saying an exclusive lock
|
||||
// is needed (for instance, for an upgrade/migration), give up the read lock,
|
||||
// call again with an exclusive lock, then swap back out for a read lock.
|
||||
func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) {
|
||||
p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, shared)
|
||||
if err == nil ||
|
||||
(err != nil && err != errNeedExclusiveLock) {
|
||||
return p, lock, err
|
||||
}
|
||||
|
||||
// Try again while asking for an exlusive lock
|
||||
p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, exclusive)
|
||||
if err != nil || p == nil || lock == nil {
|
||||
return p, lock, err
|
||||
}
|
||||
|
||||
lock.Unlock()
|
||||
|
||||
p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, shared)
|
||||
return p, lock, err
|
||||
}
|
||||
|
||||
// Get the policy with an exclusive lock
|
||||
func (lm *lockManager) GetPolicyExclusive(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) {
|
||||
p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, exclusive)
|
||||
return p, lock, err
|
||||
}
|
||||
|
||||
// Get the policy with a read lock; if it returns that an exclusive lock is
|
||||
// needed, retry. If successful, call one more time to get a read lock and
|
||||
// return the value.
|
||||
func (lm *lockManager) GetPolicyUpsert(storage logical.Storage, name string, derived bool) (*Policy, *sync.RWMutex, bool, error) {
|
||||
p, lock, _, err := lm.getPolicyCommon(storage, name, true, derived, shared)
|
||||
if err == nil ||
|
||||
(err != nil && err != errNeedExclusiveLock) {
|
||||
return p, lock, false, err
|
||||
}
|
||||
|
||||
// Try again while asking for an exlusive lock
|
||||
p, lock, upserted, err := lm.getPolicyCommon(storage, name, true, derived, exclusive)
|
||||
if err != nil || p == nil || lock == nil {
|
||||
return p, lock, upserted, err
|
||||
}
|
||||
|
||||
lock.Unlock()
|
||||
|
||||
// Now get a shared lock for the return, but preserve the value of upsert
|
||||
p, lock, _, err = lm.getPolicyCommon(storage, name, true, derived, shared)
|
||||
|
||||
return p, lock, upserted, err
|
||||
}
|
||||
|
||||
// When the function returns, a lock will be held on the policy if err == nil.
|
||||
// It is the caller's responsibility to unlock.
|
||||
func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, upsert, derived, lockType bool) (*Policy, *sync.RWMutex, bool, error) {
|
||||
lock := lm.policyLock(name, lockType)
|
||||
|
||||
var p *Policy
|
||||
var err error
|
||||
|
||||
// Check if it's in our cache. If so, return right away.
|
||||
if lm.CacheActive() {
|
||||
lm.cacheMutex.RLock()
|
||||
p = lm.cache[name]
|
||||
if p != nil {
|
||||
lm.cacheMutex.RUnlock()
|
||||
return p, lock, false, nil
|
||||
}
|
||||
lm.cacheMutex.RUnlock()
|
||||
}
|
||||
|
||||
// Load it from storage
|
||||
p, err = lm.getStoredPolicy(storage, name)
|
||||
if err != nil {
|
||||
lm.UnlockPolicy(lock, lockType)
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
if p == nil {
|
||||
// This is the only place we upsert a new policy, so if upsert is not
|
||||
// specified, or the lock type is wrong, unllock before returning
|
||||
if !upsert {
|
||||
lm.UnlockPolicy(lock, lockType)
|
||||
return nil, nil, false, nil
|
||||
}
|
||||
|
||||
if lockType != exclusive {
|
||||
lm.UnlockPolicy(lock, lockType)
|
||||
return nil, nil, false, errNeedExclusiveLock
|
||||
}
|
||||
|
||||
p = &Policy{
|
||||
Name: name,
|
||||
CipherMode: "aes-gcm",
|
||||
Derived: derived,
|
||||
}
|
||||
if derived {
|
||||
p.KDFMode = kdfMode
|
||||
}
|
||||
|
||||
err = p.rotate(storage)
|
||||
if err != nil {
|
||||
lm.UnlockPolicy(lock, lockType)
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
if lm.CacheActive() {
|
||||
// Since we didn't have the policy in the cache, if there was no
|
||||
// error, write the value in.
|
||||
lm.cacheMutex.Lock()
|
||||
defer lm.cacheMutex.Unlock()
|
||||
// Make sure a policy didn't appear. If so, it will only be set if
|
||||
// there was no error, so assume it's good and return that
|
||||
exp := lm.cache[name]
|
||||
if exp != nil {
|
||||
return exp, lock, false, nil
|
||||
}
|
||||
if err == nil {
|
||||
lm.cache[name] = p
|
||||
}
|
||||
}
|
||||
|
||||
// We don't need to worry about upgrading since it will be a new policy
|
||||
return p, lock, true, nil
|
||||
}
|
||||
|
||||
if p.needsUpgrade() {
|
||||
if lockType == shared {
|
||||
lm.UnlockPolicy(lock, lockType)
|
||||
return nil, nil, false, errNeedExclusiveLock
|
||||
}
|
||||
|
||||
err = p.upgrade(storage)
|
||||
if err != nil {
|
||||
lm.UnlockPolicy(lock, lockType)
|
||||
return nil, nil, false, err
|
||||
}
|
||||
}
|
||||
|
||||
if lm.CacheActive() {
|
||||
// Since we didn't have the policy in the cache, if there was no
|
||||
// error, write the value in.
|
||||
lm.cacheMutex.Lock()
|
||||
defer lm.cacheMutex.Unlock()
|
||||
// Make sure a policy didn't appear. If so, it will only be set if
|
||||
// there was no error, so assume it's good and return that
|
||||
exp := lm.cache[name]
|
||||
if exp != nil {
|
||||
return exp, lock, false, nil
|
||||
}
|
||||
if err == nil {
|
||||
lm.cache[name] = p
|
||||
}
|
||||
}
|
||||
|
||||
return p, lock, false, nil
|
||||
}
|
||||
|
||||
func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error {
|
||||
lm.cacheMutex.Lock()
|
||||
lock := lm.policyLock(name, exclusive)
|
||||
defer lock.Unlock()
|
||||
defer lm.cacheMutex.Unlock()
|
||||
|
||||
var p *Policy
|
||||
var err error
|
||||
|
||||
if lm.CacheActive() {
|
||||
p = lm.cache[name]
|
||||
}
|
||||
if p == nil {
|
||||
p, err = lm.getStoredPolicy(storage, name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if p == nil {
|
||||
return fmt.Errorf("could not delete policy; not found")
|
||||
}
|
||||
}
|
||||
|
||||
if !p.DeletionAllowed {
|
||||
return fmt.Errorf("deletion is not allowed for this policy")
|
||||
}
|
||||
|
||||
err = storage.Delete("policy/" + name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error deleting policy %s: %s", name, err)
|
||||
}
|
||||
|
||||
err = storage.Delete("archive/" + name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error deleting archive %s: %s", name, err)
|
||||
}
|
||||
|
||||
if lm.CacheActive() {
|
||||
delete(lm.cache, name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lm *lockManager) getStoredPolicy(storage logical.Storage, name string) (*Policy, error) {
|
||||
// Check if the policy already exists
|
||||
raw, err := storage.Get("policy/" + name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if raw == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Decode the policy
|
||||
policy := &Policy{
|
||||
Keys: KeyEntryMap{},
|
||||
}
|
||||
err = json.Unmarshal(raw.Value, policy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return policy, nil
|
||||
}
|
||||
|
|
@ -41,25 +41,20 @@ func (b *backend) pathConfigWrite(
|
|||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
|
||||
// Check if the policy already exists
|
||||
lp, err := b.policies.getPolicy(req, name)
|
||||
// Check if the policy already exists before we lock everything
|
||||
p, lock, err := b.lm.GetPolicyExclusive(req.Storage, name)
|
||||
if lock != nil {
|
||||
defer lock.Unlock()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lp == nil {
|
||||
if p == nil {
|
||||
return logical.ErrorResponse(
|
||||
fmt.Sprintf("no existing role named %s could be found", name)),
|
||||
fmt.Sprintf("no existing key named %s could be found", name)),
|
||||
logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
lp.Lock()
|
||||
defer lp.Unlock()
|
||||
|
||||
// Verify if wasn't deleted before we grabbed the lock
|
||||
if lp.policy == nil {
|
||||
return nil, fmt.Errorf("no existing role named %s could be found", name)
|
||||
}
|
||||
|
||||
resp := &logical.Response{}
|
||||
|
||||
persistNeeded := false
|
||||
|
|
@ -78,12 +73,12 @@ func (b *backend) pathConfigWrite(
|
|||
}
|
||||
|
||||
if minDecryptionVersion > 0 &&
|
||||
minDecryptionVersion != lp.policy.MinDecryptionVersion {
|
||||
if minDecryptionVersion > lp.policy.LatestVersion {
|
||||
minDecryptionVersion != p.MinDecryptionVersion {
|
||||
if minDecryptionVersion > p.LatestVersion {
|
||||
return logical.ErrorResponse(
|
||||
fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, lp.policy.LatestVersion)), nil
|
||||
fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, p.LatestVersion)), nil
|
||||
}
|
||||
lp.policy.MinDecryptionVersion = minDecryptionVersion
|
||||
p.MinDecryptionVersion = minDecryptionVersion
|
||||
persistNeeded = true
|
||||
}
|
||||
}
|
||||
|
|
@ -91,8 +86,8 @@ func (b *backend) pathConfigWrite(
|
|||
allowDeletionInt, ok := d.GetOk("deletion_allowed")
|
||||
if ok {
|
||||
allowDeletion := allowDeletionInt.(bool)
|
||||
if allowDeletion != lp.policy.DeletionAllowed {
|
||||
lp.policy.DeletionAllowed = allowDeletion
|
||||
if allowDeletion != p.DeletionAllowed {
|
||||
p.DeletionAllowed = allowDeletion
|
||||
persistNeeded = true
|
||||
}
|
||||
}
|
||||
|
|
@ -100,8 +95,8 @@ func (b *backend) pathConfigWrite(
|
|||
// Add this as a guard here before persisting since we now require the min
|
||||
// decryption version to start at 1; even if it's not explicitly set here,
|
||||
// force the upgrade
|
||||
if lp.policy.MinDecryptionVersion == 0 {
|
||||
lp.policy.MinDecryptionVersion = 1
|
||||
if p.MinDecryptionVersion == 0 {
|
||||
p.MinDecryptionVersion = 1
|
||||
persistNeeded = true
|
||||
}
|
||||
|
||||
|
|
@ -109,7 +104,7 @@ func (b *backend) pathConfigWrite(
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
return resp, lp.policy.Persist(req.Storage)
|
||||
return resp, p.Persist(req.Storage)
|
||||
}
|
||||
|
||||
const pathConfigHelpSyn = `Configure a named encryption key`
|
||||
|
|
|
|||
|
|
@ -73,24 +73,17 @@ func (b *backend) pathDatakeyWrite(
|
|||
}
|
||||
|
||||
// Get the policy
|
||||
lp, err := b.policies.getPolicy(req, name)
|
||||
p, lock, err := b.lm.GetPolicyShared(req.Storage, name)
|
||||
if lock != nil {
|
||||
defer lock.RUnlock()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Error if invalid policy
|
||||
if lp == nil {
|
||||
if p == nil {
|
||||
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
lp.RLock()
|
||||
defer lp.RUnlock()
|
||||
|
||||
// Verify if wasn't deleted before we grabbed the lock
|
||||
if lp.policy == nil {
|
||||
return nil, fmt.Errorf("no existing policy named %s could be found", name)
|
||||
}
|
||||
|
||||
newKey := make([]byte, 32)
|
||||
bits := d.Get("bits").(int)
|
||||
switch bits {
|
||||
|
|
@ -107,7 +100,7 @@ func (b *backend) pathDatakeyWrite(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
ciphertext, err := lp.policy.Encrypt(context, base64.StdEncoding.EncodeToString(newKey))
|
||||
ciphertext, err := p.Encrypt(context, base64.StdEncoding.EncodeToString(newKey))
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case certutil.UserError:
|
||||
|
|
|
|||
|
|
@ -58,25 +58,18 @@ func (b *backend) pathDecryptWrite(
|
|||
}
|
||||
|
||||
// Get the policy
|
||||
lp, err := b.policies.getPolicy(req, name)
|
||||
p, lock, err := b.lm.GetPolicyShared(req.Storage, name)
|
||||
if lock != nil {
|
||||
defer lock.RUnlock()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Error if invalid policy
|
||||
if lp == nil {
|
||||
if p == nil {
|
||||
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
lp.RLock()
|
||||
defer lp.RUnlock()
|
||||
|
||||
// Verify if wasn't deleted before we grabbed the lock
|
||||
if lp.policy == nil {
|
||||
return nil, fmt.Errorf("no existing policy named %s could be found", name)
|
||||
}
|
||||
|
||||
plaintext, err := lp.policy.Decrypt(context, ciphertext)
|
||||
plaintext, err := p.Decrypt(context, ciphertext)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case certutil.UserError:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package transit
|
|||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
|
|
@ -44,12 +45,14 @@ func (b *backend) pathEncrypt() *framework.Path {
|
|||
func (b *backend) pathEncryptExistenceCheck(
|
||||
req *logical.Request, d *framework.FieldData) (bool, error) {
|
||||
name := d.Get("name").(string)
|
||||
lp, err := b.policies.getPolicy(req, name)
|
||||
p, lock, err := b.lm.GetPolicyShared(req.Storage, name)
|
||||
if lock != nil {
|
||||
defer lock.RUnlock()
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return lp != nil, nil
|
||||
return p != nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathEncryptWrite(
|
||||
|
|
@ -63,8 +66,8 @@ func (b *backend) pathEncryptWrite(
|
|||
// Decode the context if any
|
||||
contextRaw := d.Get("context").(string)
|
||||
var context []byte
|
||||
var err error
|
||||
if len(contextRaw) != 0 {
|
||||
var err error
|
||||
context, err = base64.StdEncoding.DecodeString(contextRaw)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse("failed to decode context as base64"), logical.ErrInvalidRequest
|
||||
|
|
@ -72,37 +75,25 @@ func (b *backend) pathEncryptWrite(
|
|||
}
|
||||
|
||||
// Get the policy
|
||||
lp, err := b.policies.getPolicy(req, name)
|
||||
var p *Policy
|
||||
var lock *sync.RWMutex
|
||||
var upserted bool
|
||||
if req.Operation == logical.CreateOperation {
|
||||
p, lock, upserted, err = b.lm.GetPolicyUpsert(req.Storage, name, len(context) != 0)
|
||||
} else {
|
||||
p, lock, err = b.lm.GetPolicyShared(req.Storage, name)
|
||||
}
|
||||
if lock != nil {
|
||||
defer lock.RUnlock()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Error if invalid policy
|
||||
if lp == nil {
|
||||
if req.Operation != logical.CreateOperation {
|
||||
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
isDerived := len(context) != 0
|
||||
|
||||
lp, err = b.policies.generatePolicy(req.Storage, name, isDerived)
|
||||
// If the error is that the policy has been created in the interim we
|
||||
// will get the policy back, so only consider it an error if err is not
|
||||
// nil and we do not get a policy back
|
||||
if err != nil && lp != nil {
|
||||
return nil, err
|
||||
}
|
||||
if p == nil {
|
||||
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
lp.RLock()
|
||||
defer lp.RUnlock()
|
||||
|
||||
// Verify if wasn't deleted before we grabbed the lock
|
||||
if lp.policy == nil {
|
||||
return nil, fmt.Errorf("no existing policy named %s could be found", name)
|
||||
}
|
||||
|
||||
ciphertext, err := lp.policy.Encrypt(context, value)
|
||||
ciphertext, err := p.Encrypt(context, value)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case certutil.UserError:
|
||||
|
|
@ -124,6 +115,11 @@ func (b *backend) pathEncryptWrite(
|
|||
"ciphertext": ciphertext,
|
||||
},
|
||||
}
|
||||
|
||||
if req.Operation == logical.CreateOperation && !upserted {
|
||||
resp.AddWarning("Attempted creation of the key during the encrypt operation, but it was created beforehand")
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -39,57 +39,57 @@ func (b *backend) pathPolicyWrite(
|
|||
name := d.Get("name").(string)
|
||||
derived := d.Get("derived").(bool)
|
||||
|
||||
// Check if the policy already exists
|
||||
existing, err := b.policies.getPolicy(req, name)
|
||||
p, lock, upserted, err := b.lm.GetPolicyUpsert(req.Storage, name, derived)
|
||||
if lock != nil {
|
||||
defer lock.RUnlock()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if existing != nil {
|
||||
return nil, nil
|
||||
if p == nil {
|
||||
return nil, fmt.Errorf("error generating key: returned policy was nil")
|
||||
}
|
||||
|
||||
// Generate the policy
|
||||
_, err = b.policies.generatePolicy(req.Storage, name, derived)
|
||||
return nil, err
|
||||
resp := &logical.Response{}
|
||||
if !upserted {
|
||||
resp.AddWarning(fmt.Sprintf("key %s already existed", name))
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathPolicyRead(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
|
||||
lp, err := b.policies.getPolicy(req, name)
|
||||
p, lock, err := b.lm.GetPolicyShared(req.Storage, name)
|
||||
if lock != nil {
|
||||
defer lock.RUnlock()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lp == nil {
|
||||
if p == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lp.RLock()
|
||||
defer lp.RUnlock()
|
||||
|
||||
// Verify if wasn't deleted before we grabbed the lock
|
||||
if lp.policy == nil {
|
||||
return nil, fmt.Errorf("no existing policy named %s could be found", name)
|
||||
}
|
||||
|
||||
// Return the response
|
||||
resp := &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"name": lp.policy.Name,
|
||||
"cipher_mode": lp.policy.CipherMode,
|
||||
"derived": lp.policy.Derived,
|
||||
"deletion_allowed": lp.policy.DeletionAllowed,
|
||||
"min_decryption_version": lp.policy.MinDecryptionVersion,
|
||||
"latest_version": lp.policy.LatestVersion,
|
||||
"name": p.Name,
|
||||
"cipher_mode": p.CipherMode,
|
||||
"derived": p.Derived,
|
||||
"deletion_allowed": p.DeletionAllowed,
|
||||
"min_decryption_version": p.MinDecryptionVersion,
|
||||
"latest_version": p.LatestVersion,
|
||||
},
|
||||
}
|
||||
if lp.policy.Derived {
|
||||
resp.Data["kdf_mode"] = lp.policy.KDFMode
|
||||
if p.Derived {
|
||||
resp.Data["kdf_mode"] = p.KDFMode
|
||||
}
|
||||
|
||||
retKeys := map[string]int64{}
|
||||
for k, v := range lp.policy.Keys {
|
||||
for k, v := range p.Keys {
|
||||
retKeys[strconv.Itoa(k)] = v.CreationTime
|
||||
}
|
||||
resp.Data["keys"] = retKeys
|
||||
|
|
@ -101,15 +101,8 @@ func (b *backend) pathPolicyDelete(
|
|||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
|
||||
lp, err := b.policies.getPolicy(req, name)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("error looking up policy %s, error is %s", name, err)), err
|
||||
}
|
||||
if lp == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("no such key %s", name)), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
err = b.policies.deletePolicy(req.Storage, name)
|
||||
// Delete does its own locking
|
||||
err := b.lm.DeletePolicy(req.Storage, name)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("error deleting policy %s: %s", name, err)), err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -59,25 +59,19 @@ func (b *backend) pathRewrapWrite(
|
|||
}
|
||||
|
||||
// Get the policy
|
||||
lp, err := b.policies.getPolicy(req, name)
|
||||
p, lock, err := b.lm.GetPolicyShared(req.Storage, name)
|
||||
if lock != nil {
|
||||
defer lock.RUnlock()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Error if invalid policy
|
||||
if lp == nil {
|
||||
if p == nil {
|
||||
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
lp.RLock()
|
||||
defer lp.RUnlock()
|
||||
|
||||
// Verify if wasn't deleted before we grabbed the lock
|
||||
if lp.policy == nil {
|
||||
return nil, fmt.Errorf("no existing policy named %s could be found", name)
|
||||
}
|
||||
|
||||
plaintext, err := lp.policy.Decrypt(context, value)
|
||||
plaintext, err := p.Decrypt(context, value)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case certutil.UserError:
|
||||
|
|
@ -93,7 +87,7 @@ func (b *backend) pathRewrapWrite(
|
|||
return nil, fmt.Errorf("empty plaintext returned during rewrap")
|
||||
}
|
||||
|
||||
ciphertext, err := lp.policy.Encrypt(context, plaintext)
|
||||
ciphertext, err := p.Encrypt(context, plaintext)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case certutil.UserError:
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
package transit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
|
@ -31,26 +29,19 @@ func (b *backend) pathRotateWrite(
|
|||
name := d.Get("name").(string)
|
||||
|
||||
// Get the policy
|
||||
lp, err := b.policies.getPolicy(req, name)
|
||||
p, lock, err := b.lm.GetPolicyExclusive(req.Storage, name)
|
||||
if lock != nil {
|
||||
defer lock.Unlock()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Error if invalid policy
|
||||
if lp == nil {
|
||||
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
|
||||
if p == nil {
|
||||
return logical.ErrorResponse("key not found"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
lp.Lock()
|
||||
defer lp.Unlock()
|
||||
|
||||
// Verify if wasn't deleted before we grabbed the lock
|
||||
if lp.policy == nil {
|
||||
return nil, fmt.Errorf("no existing policy named %s could be found", name)
|
||||
}
|
||||
|
||||
// Generate the policy
|
||||
err = lp.policy.rotate(req.Storage)
|
||||
// Rotate the policy
|
||||
err = p.rotate(req.Storage)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ import (
|
|||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
|
|
@ -24,197 +23,6 @@ const (
|
|||
ErrTooOld = "ciphertext version is disallowed by policy (too old)"
|
||||
)
|
||||
|
||||
// policyCache implements a simple locking cache of policies
|
||||
type policyCache struct {
|
||||
sync.RWMutex
|
||||
cache map[string]*lockingPolicy
|
||||
}
|
||||
|
||||
// getPolicy loads a policy into the cache or returns one already in the cache
|
||||
func (p *policyCache) getPolicy(req *logical.Request, name string) (*lockingPolicy, error) {
|
||||
// We don't defer this since we may need to give it up and get a write lock
|
||||
p.RLock()
|
||||
|
||||
// First, see if we're in the cache -- if so, return that
|
||||
if p.cache[name] != nil {
|
||||
defer p.RUnlock()
|
||||
return p.cache[name], nil
|
||||
}
|
||||
|
||||
// If we didn't find anything, we'll need to write into the cache, plus possibly
|
||||
// persist the entry, so lock the cache
|
||||
p.RUnlock()
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
// Check one more time to ensure that another process did not write during
|
||||
// our lock switcheroo.
|
||||
if p.cache[name] != nil {
|
||||
return p.cache[name], nil
|
||||
}
|
||||
|
||||
// Note that we don't need to create the locking entry until the end,
|
||||
// because the policy wasn't in the cache so we don't know about it, and we
|
||||
// hold the cache lock so nothing else can be writing it in right now
|
||||
|
||||
// Check if the policy already exists
|
||||
raw, err := req.Storage.Get("policy/" + name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if raw == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Decode the policy
|
||||
policy := &Policy{
|
||||
Keys: KeyEntryMap{},
|
||||
}
|
||||
err = json.Unmarshal(raw.Value, policy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
persistNeeded := false
|
||||
// Ensure we've moved from Key -> Keys
|
||||
if policy.Key != nil && len(policy.Key) > 0 {
|
||||
policy.migrateKeyToKeysMap()
|
||||
persistNeeded = true
|
||||
}
|
||||
|
||||
// With archiving, past assumptions about the length of the keys map are no longer valid
|
||||
if policy.LatestVersion == 0 && len(policy.Keys) != 0 {
|
||||
policy.LatestVersion = len(policy.Keys)
|
||||
persistNeeded = true
|
||||
}
|
||||
|
||||
// We disallow setting the version to 0, since they start at 1 since moving
|
||||
// to rotate-able keys, so update if it's set to 0
|
||||
if policy.MinDecryptionVersion == 0 {
|
||||
policy.MinDecryptionVersion = 1
|
||||
persistNeeded = true
|
||||
}
|
||||
|
||||
// On first load after an upgrade, copy keys to the archive
|
||||
if policy.ArchiveVersion == 0 {
|
||||
persistNeeded = true
|
||||
}
|
||||
|
||||
if persistNeeded {
|
||||
err = policy.Persist(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
lp := &lockingPolicy{
|
||||
policy: policy,
|
||||
}
|
||||
p.cache[name] = lp
|
||||
|
||||
return lp, nil
|
||||
}
|
||||
|
||||
// generatePolicy is used to create a new named policy with a randomly
|
||||
// generated key
|
||||
func (p *policyCache) generatePolicy(storage logical.Storage, name string, derived bool) (*lockingPolicy, error) {
|
||||
// Ensure one with this name doesn't already exist
|
||||
lp, err := p.getPolicy(&logical.Request{
|
||||
Storage: storage,
|
||||
}, name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error checking if policy already exists: %s", err)
|
||||
}
|
||||
if lp != nil {
|
||||
return nil, fmt.Errorf("policy %s already exists", name)
|
||||
}
|
||||
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
// Now we need to check again in the cache to ensure the policy wasn't
|
||||
// created since we checked getPolicy. A policy being created holds a write
|
||||
// lock until it's done, so it'll be in the cache at this point.
|
||||
if lp := p.cache[name]; lp != nil {
|
||||
return nil, fmt.Errorf("policy %s already exists", name)
|
||||
}
|
||||
|
||||
// Create the policy object
|
||||
policy := &Policy{
|
||||
Name: name,
|
||||
CipherMode: "aes-gcm",
|
||||
Derived: derived,
|
||||
}
|
||||
if derived {
|
||||
policy.KDFMode = kdfMode
|
||||
}
|
||||
|
||||
err = policy.rotate(storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lp = &lockingPolicy{
|
||||
policy: policy,
|
||||
}
|
||||
p.cache[name] = lp
|
||||
|
||||
// Return the policy
|
||||
return lp, nil
|
||||
}
|
||||
|
||||
// deletePolicy deletes a policy
|
||||
func (p *policyCache) deletePolicy(storage logical.Storage, name string) error {
|
||||
// Ensure one with this name exists
|
||||
lp, err := p.getPolicy(&logical.Request{
|
||||
Storage: storage,
|
||||
}, name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error checking if policy already exists: %s", err)
|
||||
}
|
||||
if lp == nil {
|
||||
return fmt.Errorf("policy %s does not exist", name)
|
||||
}
|
||||
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
lp = p.cache[name]
|
||||
if lp == nil {
|
||||
return fmt.Errorf("policy %s not found", name)
|
||||
}
|
||||
|
||||
// We need to ensure all other access has stopped
|
||||
lp.Lock()
|
||||
defer lp.Unlock()
|
||||
|
||||
// Verify this hasn't changed
|
||||
if !lp.policy.DeletionAllowed {
|
||||
return fmt.Errorf("deletion not allowed for policy %s", name)
|
||||
}
|
||||
|
||||
err = storage.Delete("policy/" + name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error deleting policy %s: %s", name, err)
|
||||
}
|
||||
|
||||
err = storage.Delete("archive/" + name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error deleting archive %s: %s", name, err)
|
||||
}
|
||||
|
||||
lp.policy = nil
|
||||
delete(p.cache, name)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// lockingPolicy holds a Policy guarded by a lock
|
||||
type lockingPolicy struct {
|
||||
sync.RWMutex
|
||||
policy *Policy
|
||||
}
|
||||
|
||||
// KeyEntry stores the key and metadata
|
||||
type KeyEntry struct {
|
||||
Key []byte `json:"key"`
|
||||
|
|
@ -427,6 +235,67 @@ func (p *Policy) Serialize() ([]byte, error) {
|
|||
return json.Marshal(p)
|
||||
}
|
||||
|
||||
func (p *Policy) needsUpgrade() bool {
|
||||
// Ensure we've moved from Key -> Keys
|
||||
if p.Key != nil && len(p.Key) > 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// With archiving, past assumptions about the length of the keys map are no longer valid
|
||||
if p.LatestVersion == 0 && len(p.Keys) != 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// We disallow setting the version to 0, since they start at 1 since moving
|
||||
// to rotate-able keys, so update if it's set to 0
|
||||
if p.MinDecryptionVersion == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// On first load after an upgrade, copy keys to the archive
|
||||
if p.ArchiveVersion == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Policy) upgrade(storage logical.Storage) error {
|
||||
persistNeeded := false
|
||||
// Ensure we've moved from Key -> Keys
|
||||
if p.Key != nil && len(p.Key) > 0 {
|
||||
p.migrateKeyToKeysMap()
|
||||
persistNeeded = true
|
||||
}
|
||||
|
||||
// With archiving, past assumptions about the length of the keys map are no longer valid
|
||||
if p.LatestVersion == 0 && len(p.Keys) != 0 {
|
||||
p.LatestVersion = len(p.Keys)
|
||||
persistNeeded = true
|
||||
}
|
||||
|
||||
// We disallow setting the version to 0, since they start at 1 since moving
|
||||
// to rotate-able keys, so update if it's set to 0
|
||||
if p.MinDecryptionVersion == 0 {
|
||||
p.MinDecryptionVersion = 1
|
||||
persistNeeded = true
|
||||
}
|
||||
|
||||
// On first load after an upgrade, copy keys to the archive
|
||||
if p.ArchiveVersion == 0 {
|
||||
persistNeeded = true
|
||||
}
|
||||
|
||||
if persistNeeded {
|
||||
err := p.Persist(storage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeriveKey is used to derive the encryption key that should
|
||||
// be used depending on the policy. If derivation is disabled the
|
||||
// raw key is used and no context is required, otherwise the KDF
|
||||
|
|
@ -521,17 +390,17 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) {
|
|||
func (p *Policy) Decrypt(context []byte, value string) (string, error) {
|
||||
// Verify the prefix
|
||||
if !strings.HasPrefix(value, "vault:v") {
|
||||
return "", certutil.UserError{Err: "invalid ciphertext"}
|
||||
return "", certutil.UserError{Err: "invalid ciphertext: no prefix"}
|
||||
}
|
||||
|
||||
splitVerCiphertext := strings.SplitN(strings.TrimPrefix(value, "vault:v"), ":", 2)
|
||||
if len(splitVerCiphertext) != 2 {
|
||||
return "", certutil.UserError{Err: "invalid ciphertext"}
|
||||
return "", certutil.UserError{Err: "invalid ciphertext: wrong number of fields"}
|
||||
}
|
||||
|
||||
ver, err := strconv.Atoi(splitVerCiphertext[0])
|
||||
if err != nil {
|
||||
return "", certutil.UserError{Err: "invalid ciphertext"}
|
||||
return "", certutil.UserError{Err: "invalid ciphertext: version number could not be decoded"}
|
||||
}
|
||||
|
||||
if ver == 0 {
|
||||
|
|
@ -540,6 +409,10 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) {
|
|||
ver = 1
|
||||
}
|
||||
|
||||
if ver > p.LatestVersion {
|
||||
return "", certutil.UserError{Err: "invalid ciphertext: version is too new"}
|
||||
}
|
||||
|
||||
if p.MinDecryptionVersion > 0 && ver < p.MinDecryptionVersion {
|
||||
return "", certutil.UserError{Err: ErrTooOld}
|
||||
}
|
||||
|
|
@ -560,7 +433,7 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) {
|
|||
// Decode the base64
|
||||
decoded, err := base64.StdEncoding.DecodeString(splitVerCiphertext[1])
|
||||
if err != nil {
|
||||
return "", certutil.UserError{Err: "invalid ciphertext"}
|
||||
return "", certutil.UserError{Err: "invalid ciphertext: could not decode base64"}
|
||||
}
|
||||
|
||||
// Setup the cipher
|
||||
|
|
@ -582,7 +455,7 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) {
|
|||
// Verify and Decrypt
|
||||
plain, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", certutil.UserError{Err: "invalid ciphertext"}
|
||||
return "", certutil.UserError{Err: "invalid ciphertext: unable to decrypt"}
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(plain), nil
|
||||
|
|
@ -617,6 +490,8 @@ func (p *Policy) rotate(storage logical.Storage) error {
|
|||
p.MinDecryptionVersion = 1
|
||||
}
|
||||
|
||||
//fmt.Printf("policy %s rotated to %d\n", p.Name, p.LatestVersion)
|
||||
|
||||
return p.Persist(storage)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,38 +16,49 @@ func resetKeysArchive() {
|
|||
}
|
||||
|
||||
func Test_KeyUpgrade(t *testing.T) {
|
||||
testKeyUpgradeCommon(t, newLockManager(false))
|
||||
testKeyUpgradeCommon(t, newLockManager(true))
|
||||
}
|
||||
|
||||
func testKeyUpgradeCommon(t *testing.T, lm *lockManager) {
|
||||
storage := &logical.InmemStorage{}
|
||||
policies := &policyCache{
|
||||
cache: map[string]*lockingPolicy{},
|
||||
p, lock, upserted, err := lm.GetPolicyUpsert(storage, "test", false)
|
||||
if lock != nil {
|
||||
defer lock.RUnlock()
|
||||
}
|
||||
lp, err := policies.generatePolicy(storage, "test", false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if lp == nil {
|
||||
if p == nil {
|
||||
t.Fatal("nil policy")
|
||||
}
|
||||
if !upserted {
|
||||
t.Fatal("expected an upsert")
|
||||
}
|
||||
|
||||
policy := lp.policy
|
||||
testBytes := make([]byte, len(p.Keys[1].Key))
|
||||
copy(testBytes, p.Keys[1].Key)
|
||||
|
||||
testBytes := make([]byte, len(policy.Keys[1].Key))
|
||||
copy(testBytes, policy.Keys[1].Key)
|
||||
|
||||
policy.Key = policy.Keys[1].Key
|
||||
policy.Keys = nil
|
||||
policy.migrateKeyToKeysMap()
|
||||
if policy.Key != nil {
|
||||
p.Key = p.Keys[1].Key
|
||||
p.Keys = nil
|
||||
p.migrateKeyToKeysMap()
|
||||
if p.Key != nil {
|
||||
t.Fatal("policy.Key is not nil")
|
||||
}
|
||||
if len(policy.Keys) != 1 {
|
||||
if len(p.Keys) != 1 {
|
||||
t.Fatal("policy.Keys is the wrong size")
|
||||
}
|
||||
if !reflect.DeepEqual(testBytes, policy.Keys[1].Key) {
|
||||
if !reflect.DeepEqual(testBytes, p.Keys[1].Key) {
|
||||
t.Fatal("key mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ArchivingUpgrade(t *testing.T) {
|
||||
testArchivingUpgradeCommon(t, newLockManager(false))
|
||||
testArchivingUpgradeCommon(t, newLockManager(true))
|
||||
}
|
||||
|
||||
func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
|
||||
resetKeysArchive()
|
||||
|
||||
// First, we generate a policy and rotate it a number of times. Each time
|
||||
|
|
@ -56,31 +67,27 @@ func Test_ArchivingUpgrade(t *testing.T) {
|
|||
// zero and latest, respectively
|
||||
|
||||
storage := &logical.InmemStorage{}
|
||||
policies := &policyCache{
|
||||
cache: map[string]*lockingPolicy{},
|
||||
}
|
||||
|
||||
lp, err := policies.generatePolicy(storage, "test", false)
|
||||
p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if lp == nil {
|
||||
t.Fatal("policy is nil")
|
||||
if p == nil || lock == nil {
|
||||
t.Fatal("nil policy or lock")
|
||||
}
|
||||
|
||||
policy := lp.policy
|
||||
lock.RUnlock()
|
||||
|
||||
// Store the initial key in the archive
|
||||
keysArchive = append(keysArchive, policy.Keys[1])
|
||||
checkKeys(t, policy, storage, "initial", 1, 1, 1)
|
||||
keysArchive = append(keysArchive, p.Keys[1])
|
||||
checkKeys(t, p, storage, "initial", 1, 1, 1)
|
||||
|
||||
for i := 2; i <= 10; i++ {
|
||||
err = policy.rotate(storage)
|
||||
err = p.rotate(storage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
keysArchive = append(keysArchive, policy.Keys[i])
|
||||
checkKeys(t, policy, storage, "rotate", i, i, i)
|
||||
keysArchive = append(keysArchive, p.Keys[i])
|
||||
checkKeys(t, p, storage, "rotate", i, i, i)
|
||||
}
|
||||
|
||||
// Now, wipe the archive and set the archive version to zero
|
||||
|
|
@ -88,44 +95,100 @@ func Test_ArchivingUpgrade(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
policy.ArchiveVersion = 0
|
||||
p.ArchiveVersion = 0
|
||||
|
||||
// Store it, but without calling persist, so we don't trigger
|
||||
// handleArchiving()
|
||||
buf, err := policy.Serialize()
|
||||
buf, err := p.Serialize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Write the policy into storage
|
||||
err = storage.Put(&logical.StorageEntry{
|
||||
Key: "policy/" + policy.Name,
|
||||
Key: "policy/" + p.Name,
|
||||
Value: buf,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Expire from the cache since we modified it under-the-hood
|
||||
delete(policies.cache, "test")
|
||||
// If we're caching, expire from the cache since we modified it
|
||||
// under-the-hood
|
||||
if lm.CacheActive() {
|
||||
delete(lm.cache, "test")
|
||||
}
|
||||
|
||||
// Now get the policy again; the upgrade should happen automatically
|
||||
lp, err = policies.getPolicy(&logical.Request{
|
||||
Storage: storage,
|
||||
}, "test")
|
||||
p, lock, err = lm.GetPolicyShared(storage, "test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if lp == nil {
|
||||
t.Fatal("policy is nil")
|
||||
if p == nil || lock == nil {
|
||||
t.Fatal("nil policy or lock")
|
||||
}
|
||||
lock.RUnlock()
|
||||
|
||||
checkKeys(t, p, storage, "upgrade", 10, 10, 10)
|
||||
|
||||
// Let's check some deletion logic while we're at it
|
||||
|
||||
// The policy should be in there
|
||||
if lm.CacheActive() && lm.cache["test"] == nil {
|
||||
t.Fatal("nil policy in cache")
|
||||
}
|
||||
|
||||
policy = lp.policy
|
||||
// First we'll do this wrong, by not setting the deletion flag
|
||||
err = lm.DeletePolicy(storage, "test")
|
||||
if err == nil {
|
||||
t.Fatal("got nil error, but should not have been able to delete since we didn't set the deletion flag on the policy")
|
||||
}
|
||||
|
||||
checkKeys(t, policy, storage, "upgrade", 10, 10, 10)
|
||||
// The policy should still be in there
|
||||
if lm.CacheActive() && lm.cache["test"] == nil {
|
||||
t.Fatal("nil policy in cache")
|
||||
}
|
||||
|
||||
p, lock, err = lm.GetPolicyShared(storage, "test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if p == nil || lock == nil {
|
||||
t.Fatal("policy or lock nil after bad delete")
|
||||
}
|
||||
lock.RUnlock()
|
||||
|
||||
// Now do it properly
|
||||
p.DeletionAllowed = true
|
||||
err = p.Persist(storage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = lm.DeletePolicy(storage, "test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// The policy should *not* be in there
|
||||
if lm.CacheActive() && lm.cache["test"] != nil {
|
||||
t.Fatal("non-nil policy in cache")
|
||||
}
|
||||
|
||||
p, lock, err = lm.GetPolicyShared(storage, "test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if p != nil || lock != nil {
|
||||
t.Fatal("policy or lock not nil after delete")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Archiving(t *testing.T) {
|
||||
testArchivingCommon(t, newLockManager(false))
|
||||
testArchivingCommon(t, newLockManager(true))
|
||||
}
|
||||
|
||||
func testArchivingCommon(t *testing.T, lm *lockManager) {
|
||||
resetKeysArchive()
|
||||
|
||||
// First, we generate a policy and rotate it a number of times. Each time
|
||||
|
|
@ -135,38 +198,35 @@ func Test_Archiving(t *testing.T) {
|
|||
|
||||
storage := &logical.InmemStorage{}
|
||||
|
||||
policies := &policyCache{
|
||||
cache: map[string]*lockingPolicy{},
|
||||
p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false)
|
||||
if lock != nil {
|
||||
defer lock.RUnlock()
|
||||
}
|
||||
|
||||
lp, err := policies.generatePolicy(storage, "test", false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if lp == nil {
|
||||
t.Fatal("policy is nil")
|
||||
if p == nil {
|
||||
t.Fatal("nil policy")
|
||||
}
|
||||
|
||||
policy := lp.policy
|
||||
|
||||
// Store the initial key in the archive
|
||||
keysArchive = append(keysArchive, policy.Keys[1])
|
||||
checkKeys(t, policy, storage, "initial", 1, 1, 1)
|
||||
keysArchive = append(keysArchive, p.Keys[1])
|
||||
checkKeys(t, p, storage, "initial", 1, 1, 1)
|
||||
|
||||
for i := 2; i <= 10; i++ {
|
||||
err = policy.rotate(storage)
|
||||
err = p.rotate(storage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
keysArchive = append(keysArchive, policy.Keys[i])
|
||||
checkKeys(t, policy, storage, "rotate", i, i, i)
|
||||
keysArchive = append(keysArchive, p.Keys[i])
|
||||
checkKeys(t, p, storage, "rotate", i, i, i)
|
||||
}
|
||||
|
||||
// Move the min decryption version up
|
||||
for i := 1; i <= 10; i++ {
|
||||
policy.MinDecryptionVersion = i
|
||||
p.MinDecryptionVersion = i
|
||||
|
||||
err = policy.Persist(storage)
|
||||
err = p.Persist(storage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -178,14 +238,14 @@ func Test_Archiving(t *testing.T) {
|
|||
// 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min
|
||||
// decryption version plus 1 (the min decryption version key
|
||||
// itself)
|
||||
checkKeys(t, policy, storage, "minadd", 10, 10, policy.LatestVersion-policy.MinDecryptionVersion+1)
|
||||
checkKeys(t, p, storage, "minadd", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1)
|
||||
}
|
||||
|
||||
// Move the min decryption version down
|
||||
for i := 10; i >= 1; i-- {
|
||||
policy.MinDecryptionVersion = i
|
||||
p.MinDecryptionVersion = i
|
||||
|
||||
err = policy.Persist(storage)
|
||||
err = p.Persist(storage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -197,7 +257,7 @@ func Test_Archiving(t *testing.T) {
|
|||
// 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min
|
||||
// decryption version plus 1 (the min decryption version key
|
||||
// itself)
|
||||
checkKeys(t, policy, storage, "minsub", 10, 10, policy.LatestVersion-policy.MinDecryptionVersion+1)
|
||||
checkKeys(t, p, storage, "minsub", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -26,6 +26,10 @@ type SystemView interface {
|
|||
// when the stored CRL will be removed during the unmounting process
|
||||
// anyways), we can ignore the errors to allow unmounting to complete.
|
||||
Tainted() bool
|
||||
|
||||
// Returns true if caching is disabled. If true, no caches should be used,
|
||||
// despite known slowdowns.
|
||||
CachingDisabled() bool
|
||||
}
|
||||
|
||||
type StaticSystemView struct {
|
||||
|
|
@ -33,6 +37,7 @@ type StaticSystemView struct {
|
|||
MaxLeaseTTLVal time.Duration
|
||||
SudoPrivilegeVal bool
|
||||
TaintedVal bool
|
||||
CachingDisabledVal bool
|
||||
}
|
||||
|
||||
func (d StaticSystemView) DefaultLeaseTTL() time.Duration {
|
||||
|
|
@ -50,3 +55,7 @@ func (d StaticSystemView) SudoPrivilege(path string, token string) bool {
|
|||
func (d StaticSystemView) Tainted() bool {
|
||||
return d.TaintedVal
|
||||
}
|
||||
|
||||
func (d StaticSystemView) CachingDisabled() bool {
|
||||
return d.CachingDisabledVal
|
||||
}
|
||||
|
|
|
|||
|
|
@ -218,6 +218,9 @@ type Core struct {
|
|||
maxLeaseTTL time.Duration
|
||||
|
||||
logger *log.Logger
|
||||
|
||||
// cachingDisabled indicates whether caches are disabled
|
||||
cachingDisabled bool
|
||||
}
|
||||
|
||||
// CoreConfig is used to parameterize a core
|
||||
|
|
@ -315,6 +318,7 @@ func NewCore(conf *CoreConfig) (*Core, error) {
|
|||
logger: conf.Logger,
|
||||
defaultLeaseTTL: conf.DefaultLeaseTTL,
|
||||
maxLeaseTTL: conf.MaxLeaseTTL,
|
||||
cachingDisabled: conf.DisableCache,
|
||||
}
|
||||
|
||||
// Setup the backends
|
||||
|
|
|
|||
|
|
@ -69,3 +69,8 @@ func (d dynamicSystemView) fetchTTLs() (def, max time.Duration) {
|
|||
func (d dynamicSystemView) Tainted() bool {
|
||||
return d.mountEntry.Tainted
|
||||
}
|
||||
|
||||
// CachingDisabled indicates whether to use caching behavior
|
||||
func (d dynamicSystemView) CachingDisabled() bool {
|
||||
return d.core.cachingDisabled
|
||||
}
|
||||
|
|
|
|||
|
|
@ -34,12 +34,15 @@ type PolicyEntry struct {
|
|||
|
||||
// NewPolicyStore creates a new PolicyStore that is backed
|
||||
// using a given view. It used used to durable store and manage named policy.
|
||||
func NewPolicyStore(view *BarrierView) *PolicyStore {
|
||||
cache, _ := lru.New2Q(policyCacheSize)
|
||||
func NewPolicyStore(view *BarrierView, system logical.SystemView) *PolicyStore {
|
||||
p := &PolicyStore{
|
||||
view: view,
|
||||
lru: cache,
|
||||
}
|
||||
if !system.CachingDisabled() {
|
||||
cache, _ := lru.New2Q(policyCacheSize)
|
||||
p.lru = cache
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
|
|
@ -50,7 +53,7 @@ func (c *Core) setupPolicyStore() error {
|
|||
view := c.systemBarrierView.SubView(policySubPath)
|
||||
|
||||
// Create the policy store
|
||||
c.policyStore = NewPolicyStore(view)
|
||||
c.policyStore = NewPolicyStore(view, &dynamicSystemView{core: c})
|
||||
|
||||
// Ensure that the default policy exists, and if not, create it
|
||||
policy, err := c.policyStore.GetPolicy("default")
|
||||
|
|
@ -95,23 +98,29 @@ func (ps *PolicyStore) SetPolicy(p *Policy) error {
|
|||
return fmt.Errorf("failed to persist policy: %v", err)
|
||||
}
|
||||
|
||||
// Update the LRU cache
|
||||
ps.lru.Add(p.Name, p)
|
||||
if ps.lru != nil {
|
||||
// Update the LRU cache
|
||||
ps.lru.Add(p.Name, p)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPolicy is used to fetch the named policy
|
||||
func (ps *PolicyStore) GetPolicy(name string) (*Policy, error) {
|
||||
defer metrics.MeasureSince([]string{"policy", "get_policy"}, time.Now())
|
||||
// Check for cached policy
|
||||
if raw, ok := ps.lru.Get(name); ok {
|
||||
return raw.(*Policy), nil
|
||||
if ps.lru != nil {
|
||||
// Check for cached policy
|
||||
if raw, ok := ps.lru.Get(name); ok {
|
||||
return raw.(*Policy), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Special case the root policy
|
||||
if name == "root" {
|
||||
p := &Policy{Name: "root"}
|
||||
ps.lru.Add(p.Name, p)
|
||||
if ps.lru != nil {
|
||||
ps.lru.Add(p.Name, p)
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
|
|
@ -152,8 +161,11 @@ func (ps *PolicyStore) GetPolicy(name string) (*Policy, error) {
|
|||
policy = p
|
||||
}
|
||||
|
||||
// Update the LRU cache
|
||||
ps.lru.Add(name, policy)
|
||||
if ps.lru != nil {
|
||||
// Update the LRU cache
|
||||
ps.lru.Add(name, policy)
|
||||
}
|
||||
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
|
|
@ -178,8 +190,10 @@ func (ps *PolicyStore) DeletePolicy(name string) error {
|
|||
return fmt.Errorf("failed to delete policy: %v", err)
|
||||
}
|
||||
|
||||
// Clear the cache
|
||||
ps.lru.Remove(name)
|
||||
if ps.lru != nil {
|
||||
// Clear the cache
|
||||
ps.lru.Remove(name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,16 @@ import (
|
|||
func mockPolicyStore(t *testing.T) *PolicyStore {
|
||||
_, barrier, _ := mockBarrier(t)
|
||||
view := NewBarrierView(barrier, "foo/")
|
||||
p := NewPolicyStore(view)
|
||||
p := NewPolicyStore(view, logical.TestSystemView())
|
||||
return p
|
||||
}
|
||||
|
||||
func mockPolicyStoreNoCache(t *testing.T) *PolicyStore {
|
||||
sysView := logical.TestSystemView()
|
||||
sysView.CachingDisabledVal = true
|
||||
_, barrier, _ := mockBarrier(t)
|
||||
view := NewBarrierView(barrier, "foo/")
|
||||
p := NewPolicyStore(view, sysView)
|
||||
return p
|
||||
}
|
||||
|
||||
|
|
@ -44,7 +53,13 @@ func TestPolicyStore_Root(t *testing.T) {
|
|||
|
||||
func TestPolicyStore_CRUD(t *testing.T) {
|
||||
ps := mockPolicyStore(t)
|
||||
testPolicyStore_CRUD(t, ps)
|
||||
|
||||
ps = mockPolicyStoreNoCache(t)
|
||||
testPolicyStore_CRUD(t, ps)
|
||||
}
|
||||
|
||||
func testPolicyStore_CRUD(t *testing.T, ps *PolicyStore) {
|
||||
// Get should return nothing
|
||||
p, err := ps.GetPolicy("dev")
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -50,9 +50,9 @@ sending a SIGHUP to the server process. These are denoted below.
|
|||
"tcp" is currently the only option available. A full reference for the
|
||||
inner syntax is below.
|
||||
|
||||
* `disable_cache` (optional) - A boolean. If true, this will disable the
|
||||
read cache used by the physical storage subsystem. This will very
|
||||
significantly impact performance.
|
||||
* `disable_cache` (optional) - A boolean. If true, this will disable all caches
|
||||
within Vault, including the read cache used by the physical storage
|
||||
subsystem. This will very significantly impact performance.
|
||||
|
||||
* `disable_mlock` (optional) - A boolean. If true, this will disable the
|
||||
server from executing the `mlock` syscall to prevent memory from being
|
||||
|
|
|
|||
Loading…
Reference in a new issue