From 44f8cd1d03b19247dbefaf21935797364b8a7dac Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Tue, 13 Sep 2016 11:50:14 -0400 Subject: [PATCH] Rejig locks during unmount/remount. (#1855) --- vault/audit.go | 6 ++--- vault/auth.go | 8 +++---- vault/mount.go | 59 +++++++++++++++++++++-------------------------- vault/rollback.go | 15 +++++++----- 4 files changed, 42 insertions(+), 46 deletions(-) diff --git a/vault/audit.go b/vault/audit.go index f875d4d7ec..d5831e0b96 100644 --- a/vault/audit.go +++ b/vault/audit.go @@ -82,7 +82,7 @@ func (c *Core) enableAudit(entry *MountEntry) error { return err } - newTable := c.audit.ShallowClone() + newTable := c.audit.shallowClone() newTable.Entries = append(newTable.Entries, entry) if err := c.persistAudit(newTable); err != nil { return errors.New("failed to update audit table") @@ -109,8 +109,8 @@ func (c *Core) disableAudit(path string) error { c.auditLock.Lock() defer c.auditLock.Unlock() - newTable := c.audit.ShallowClone() - found := newTable.Remove(path) + newTable := c.audit.shallowClone() + found := newTable.remove(path) // Ensure there was a match if !found { diff --git a/vault/auth.go b/vault/auth.go index d98923f199..988a721019 100644 --- a/vault/auth.go +++ b/vault/auth.go @@ -81,7 +81,7 @@ func (c *Core) enableCredential(entry *MountEntry) error { } // Update the auth table - newTable := c.auth.ShallowClone() + newTable := c.auth.shallowClone() newTable.Entries = append(newTable.Entries, entry) if err := c.persistAuth(newTable); err != nil { return errors.New("failed to update auth table") @@ -162,8 +162,8 @@ func (c *Core) disableCredential(path string) error { // removeCredEntry is used to remove an entry in the auth table func (c *Core) removeCredEntry(path string) error { // Taint the entry from the auth table - newTable := c.auth.ShallowClone() - newTable.Remove(path) + newTable := c.auth.shallowClone() + newTable.remove(path) // Update the auth table if err := c.persistAuth(newTable); err != nil { @@ -180,7 +180,7 @@ func (c *Core) taintCredEntry(path string) error { // Taint the entry from the auth table // We do this on the original since setting the taint operates // on the entries which a shallow clone shares anyways - found := c.auth.SetTaint(path, true) + found := c.auth.setTaint(path, true) // Ensure there was a match if !found { diff --git a/vault/mount.go b/vault/mount.go index 9aab152fab..00e5c8ee53 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -64,11 +64,11 @@ type MountTable struct { Entries []*MountEntry `json:"entries"` } -// ShallowClone returns a copy of the mount table that +// shallowClone returns a copy of the mount table that // keeps the MountEntry locations, so as not to invalidate // other locations holding pointers. Care needs to be taken // if modifying entries rather than modifying the table itself -func (t *MountTable) ShallowClone() *MountTable { +func (t *MountTable) shallowClone() *MountTable { mt := &MountTable{ Type: t.Type, Entries: make([]*MountEntry, len(t.Entries)), @@ -89,19 +89,8 @@ func (t *MountTable) Hash() ([]byte, error) { return hash[:], nil } -// Find is used to lookup an entry -func (t *MountTable) Find(path string) *MountEntry { - n := len(t.Entries) - for i := 0; i < n; i++ { - if t.Entries[i].Path == path { - return t.Entries[i] - } - } - return nil -} - -// SetTaint is used to set the taint on given entry -func (t *MountTable) SetTaint(path string, value bool) bool { +// setTaint is used to set the taint on given entry +func (t *MountTable) setTaint(path string, value bool) bool { n := len(t.Entries) for i := 0; i < n; i++ { if t.Entries[i].Path == path { @@ -112,8 +101,8 @@ func (t *MountTable) SetTaint(path string, value bool) bool { return false } -// Remove is used to remove a given path entry -func (t *MountTable) Remove(path string) bool { +// remove is used to remove a given path entry +func (t *MountTable) remove(path string) bool { n := len(t.Entries) for i := 0; i < n; i++ { if t.Entries[i].Path == path { @@ -186,9 +175,6 @@ func (c *Core) mount(me *MountEntry) error { return logical.CodedError(409, fmt.Sprintf("existing mount at %s", match)) } - c.mountsLock.Lock() - defer c.mountsLock.Unlock() - // Generate a new UUID and view meUUID, err := uuid.GenerateUUID() if err != nil { @@ -203,12 +189,15 @@ func (c *Core) mount(me *MountEntry) error { } // Update the mount table - newTable := c.mounts.ShallowClone() + c.mountsLock.Lock() + newTable := c.mounts.shallowClone() newTable.Entries = append(newTable.Entries, me) if err := c.persistMounts(newTable); err != nil { + c.mountsLock.Unlock() return logical.CodedError(500, "failed to update mount table") } c.mounts = newTable + c.mountsLock.Unlock() // Mount the backend if err := c.router.Mount(backend, me.Path, me, view); err != nil { @@ -240,18 +229,16 @@ func (c *Core) unmount(path string) error { return fmt.Errorf("no matching mount") } - // Store the view for this backend + // Get the view for this backend view := c.router.MatchingStorageView(path) - c.mountsLock.Lock() - defer c.mountsLock.Unlock() - // Mark the entry as tainted if err := c.taintMountEntry(path); err != nil { return err } - // Taint the router path to prevent routing + // Taint the router path to prevent routing. Note that in-flight requests + // are uncertain, right now. if err := c.router.Taint(path); err != nil { return err } @@ -288,9 +275,12 @@ func (c *Core) unmount(path string) error { // removeMountEntry is used to remove an entry from the mount table func (c *Core) removeMountEntry(path string) error { + c.mountsLock.Lock() + defer c.mountsLock.Unlock() + // Remove the entry from the mount table - newTable := c.mounts.ShallowClone() - newTable.Remove(path) + newTable := c.mounts.shallowClone() + newTable.remove(path) // Update the mount table if err := c.persistMounts(newTable); err != nil { @@ -303,9 +293,12 @@ func (c *Core) removeMountEntry(path string) error { // taintMountEntry is used to mark an entry in the mount table as tainted func (c *Core) taintMountEntry(path string) error { + c.mountsLock.Lock() + defer c.mountsLock.Unlock() + // As modifying the taint of an entry affects shallow clones, // we simply use the original - c.mounts.SetTaint(path, true) + c.mounts.setTaint(path, true) // Update the mount table if err := c.persistMounts(c.mounts); err != nil { @@ -342,9 +335,6 @@ func (c *Core) remount(src, dst string) error { return fmt.Errorf("existing mount at '%s'", match) } - c.mountsLock.Lock() - defer c.mountsLock.Unlock() - // Mark the entry as tainted if err := c.taintMountEntry(src); err != nil { return err @@ -365,6 +355,7 @@ func (c *Core) remount(src, dst string) error { return err } + c.mountsLock.Lock() var ent *MountEntry for _, ent = range c.mounts.Entries { if ent.Path == src { @@ -378,8 +369,10 @@ func (c *Core) remount(src, dst string) error { if err := c.persistMounts(c.mounts); err != nil { ent.Path = src ent.Tainted = true + c.mountsLock.Unlock() return logical.CodedError(500, "failed to update mount table") } + c.mountsLock.Unlock() // Remount the backend if err := c.router.Remount(src, dst); err != nil { @@ -570,7 +563,7 @@ func (c *Core) unloadMounts() error { defer c.mountsLock.Unlock() if c.mounts != nil { - mountTable := c.mounts.ShallowClone() + mountTable := c.mounts.shallowClone() for _, e := range mountTable.Entries { prefix := e.Path b, ok := c.router.root.Get(prefix) diff --git a/vault/rollback.go b/vault/rollback.go index e7479cbb20..b72d424490 100644 --- a/vault/rollback.go +++ b/vault/rollback.go @@ -41,7 +41,7 @@ type RollbackManager struct { inflightAll sync.WaitGroup inflight map[string]*rollbackState - inflightLock sync.Mutex + inflightLock sync.RWMutex doneCh chan struct{} shutdown bool @@ -107,8 +107,6 @@ func (m *RollbackManager) run() { // triggerRollbacks is used to trigger the rollbacks across all the backends func (m *RollbackManager) triggerRollbacks() { - m.inflightLock.Lock() - defer m.inflightLock.Unlock() backends := m.backends() @@ -117,7 +115,10 @@ func (m *RollbackManager) triggerRollbacks() { if e.Table == credentialTableType { path = "auth/" + path } - if _, ok := m.inflight[path]; !ok { + m.inflightLock.RLock() + _, ok := m.inflight[path] + m.inflightLock.RUnlock() + if !ok { m.startRollback(path) } } @@ -129,7 +130,9 @@ func (m *RollbackManager) startRollback(path string) *rollbackState { rs := &rollbackState{} rs.Add(1) m.inflightAll.Add(1) + m.inflightLock.Lock() m.inflight[path] = rs + m.inflightLock.Unlock() go m.attemptRollback(path, rs) return rs } @@ -172,12 +175,12 @@ func (m *RollbackManager) attemptRollback(path string, rs *rollbackState) (err e // or to join an existing rollback operation if in flight. func (m *RollbackManager) Rollback(path string) error { // Check for an existing attempt and start one if none - m.inflightLock.Lock() + m.inflightLock.RLock() rs, ok := m.inflight[path] + m.inflightLock.RUnlock() if !ok { rs = m.startRollback(path) } - m.inflightLock.Unlock() // Wait for the attempt to finish rs.Wait()