diff --git a/vault/logical_system.go b/vault/logical_system.go index 52f1c7a0b3..89ec1a379d 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -3,6 +3,7 @@ package vault import ( "fmt" "strings" + "sync" "time" "github.com/hashicorp/vault/logical" @@ -845,6 +846,14 @@ func (b *SystemBackend) handleMountTuneWrite( return handleError(err) } + var lock *sync.RWMutex + switch { + case strings.HasPrefix(path, "auth/"): + lock = &b.Core.authLock + default: + lock = &b.Core.mountsLock + } + // Timing configuration parameters { var newDefault, newMax *time.Duration @@ -877,8 +886,9 @@ func (b *SystemBackend) handleMountTuneWrite( } if newDefault != nil || newMax != nil { - b.Core.mountsLock.Lock() - defer b.Core.mountsLock.Unlock() + lock.Lock() + defer lock.Unlock() + if err := b.tuneMountTTLs(path, &mountEntry.Config, newDefault, newMax); err != nil { b.Backend.Logger().Printf("[ERR] sys: tune of path '%s' failed: %v", path, err) return handleError(err) diff --git a/vault/logical_system_helpers.go b/vault/logical_system_helpers.go index b6d35ef76a..a3fb945794 100644 --- a/vault/logical_system_helpers.go +++ b/vault/logical_system_helpers.go @@ -1,8 +1,8 @@ package vault import ( - "errors" "fmt" + "strings" "time" ) @@ -51,6 +51,9 @@ func (b *SystemBackend) tuneMountTTLs(path string, meConfig *MountConfig, newDef } } + origMax := meConfig.MaxLeaseTTL + origDefault := meConfig.DefaultLeaseTTL + if newMax != nil { meConfig.MaxLeaseTTL = *newMax } @@ -59,8 +62,17 @@ func (b *SystemBackend) tuneMountTTLs(path string, meConfig *MountConfig, newDef } // Update the mount table - if err := b.Core.persistMounts(b.Core.mounts); err != nil { - return errors.New("failed to update mount table") + var err error + switch { + case strings.HasPrefix(path, "auth/"): + err = b.Core.persistAuth(b.Core.auth) + default: + err = b.Core.persistMounts(b.Core.mounts) + } + if err != nil { + meConfig.MaxLeaseTTL = origMax + meConfig.DefaultLeaseTTL = origDefault + return fmt.Errorf("failed to update mount table, rolling back TTL changes") } b.Core.logger.Printf("[INFO] core: tuned '%s'", path)