diff --git a/command/server.go b/command/server.go index 3fc103304b..5f90b91d19 100644 --- a/command/server.go +++ b/command/server.go @@ -274,7 +274,7 @@ func (c *ServerCommand) Run(args []string) int { // Initialize the listeners lns := make([]net.Listener, 0, len(config.Listeners)) for i, lnConfig := range config.Listeners { - ln, props, err := server.NewListener(lnConfig.Type, lnConfig.Config) + ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config) if err != nil { c.Ui.Error(fmt.Sprintf( "Error initializing listener of type %s: %s", @@ -295,6 +295,8 @@ func (c *ServerCommand) Run(args []string) int { "%s (%s)", lnConfig.Type, strings.Join(propsList, ", ")) lns = append(lns, ln) + + core.AddReloadFunc("listener-"+ln.Addr().String(), reloadFunc) } infoKeys = append(infoKeys, "version") diff --git a/command/server/listener.go b/command/server/listener.go index 495f7ada0a..86d9f48d10 100644 --- a/command/server/listener.go +++ b/command/server/listener.go @@ -8,10 +8,13 @@ import ( "fmt" "net" "strconv" + "sync" + + "github.com/hashicorp/vault/vault" ) // ListenerFactory is the factory function to create a listener. -type ListenerFactory func(map[string]string) (net.Listener, map[string]string, error) +type ListenerFactory func(map[string]string) (net.Listener, map[string]string, vault.ReloadFunc, error) // BuiltinListeners is the list of built-in listener types. var BuiltinListeners = map[string]ListenerFactory{ @@ -27,10 +30,10 @@ var tlsLookup = map[string]uint16{ // NewListener creates a new listener of the given type with the given // configuration. The type is looked up in the BuiltinListeners map. -func NewListener(t string, config map[string]string) (net.Listener, map[string]string, error) { +func NewListener(t string, config map[string]string) (net.Listener, map[string]string, vault.ReloadFunc, error) { f, ok := BuiltinListeners[t] if !ok { - return nil, nil, fmt.Errorf("unknown listener type: %s", t) + return nil, nil, nil, fmt.Errorf("unknown listener type: %s", t) } return f(config) @@ -39,32 +42,35 @@ func NewListener(t string, config map[string]string) (net.Listener, map[string]s func listenerWrapTLS( ln net.Listener, props map[string]string, - config map[string]string) (net.Listener, map[string]string, error) { + config map[string]string) (net.Listener, map[string]string, vault.ReloadFunc, error) { props["tls"] = "disabled" if v, ok := config["tls_disable"]; ok { disabled, err := strconv.ParseBool(v) if err != nil { - return nil, nil, fmt.Errorf("invalid value for 'tls_disable': %v", err) + return nil, nil, nil, fmt.Errorf("invalid value for 'tls_disable': %v", err) } if disabled { - return ln, props, nil + return ln, props, nil, nil } } - certFile, ok := config["tls_cert_file"] + _, ok := config["tls_cert_file"] if !ok { - return nil, nil, fmt.Errorf("'tls_cert_file' must be set") + return nil, nil, nil, fmt.Errorf("'tls_cert_file' must be set") } - keyFile, ok := config["tls_key_file"] + _, ok = config["tls_key_file"] if !ok { - return nil, nil, fmt.Errorf("'tls_key_file' must be set") + return nil, nil, nil, fmt.Errorf("'tls_key_file' must be set") } - cert, err := tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return nil, nil, fmt.Errorf("error loading TLS cert: %s", err) + cg := &certificateGetter{ + config: config, + } + + if err := cg.reload(nil); err != nil { + return nil, nil, nil, fmt.Errorf("error loading TLS cert: %s", err) } tlsvers, ok := config["tls_min_version"] @@ -73,15 +79,47 @@ func listenerWrapTLS( } tlsConf := &tls.Config{} - tlsConf.Certificates = []tls.Certificate{cert} + tlsConf.GetCertificate = cg.getCertificate tlsConf.NextProtos = []string{"http/1.1"} tlsConf.MinVersion, ok = tlsLookup[tlsvers] if !ok { - return nil, nil, fmt.Errorf("'tls_min_version' value %s not supported, please specify one of [tls10,tls11,tls12]", tlsvers) + return nil, nil, nil, fmt.Errorf("'tls_min_version' value %s not supported, please specify one of [tls10,tls11,tls12]", tlsvers) } tlsConf.ClientAuth = tls.RequestClientCert ln = tls.NewListener(ln, tlsConf) props["tls"] = "enabled" - return ln, props, nil + return ln, props, cg.reload, nil +} + +type certificateGetter struct { + sync.RWMutex + + config map[string]string + cert *tls.Certificate +} + +func (cg *certificateGetter) reload(map[string]interface{}) error { + cert, err := tls.LoadX509KeyPair(cg.config["tls_cert_file"], cg.config["tls_key_file"]) + if err != nil { + return err + } + + cg.Lock() + defer cg.Unlock() + + cg.cert = &cert + + return nil +} + +func (cg *certificateGetter) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + cg.RLock() + defer cg.RUnlock() + + if cg.cert == nil { + return nil, fmt.Errorf("nil certificate") + } + + return cg.cert, nil } diff --git a/command/server/listener_tcp.go b/command/server/listener_tcp.go index f4f833dbcc..b7a0cd1f0b 100644 --- a/command/server/listener_tcp.go +++ b/command/server/listener_tcp.go @@ -3,9 +3,11 @@ package server import ( "net" "time" + + "github.com/hashicorp/vault/vault" ) -func tcpListenerFactory(config map[string]string) (net.Listener, map[string]string, error) { +func tcpListenerFactory(config map[string]string) (net.Listener, map[string]string, vault.ReloadFunc, error) { addr, ok := config["address"] if !ok { addr = "127.0.0.1:8200" @@ -13,7 +15,7 @@ func tcpListenerFactory(config map[string]string) (net.Listener, map[string]stri ln, err := net.Listen("tcp", addr) if err != nil { - return nil, nil, err + return nil, nil, nil, err } ln = tcpKeepAliveListener{ln.(*net.TCPListener)} diff --git a/vault/core.go b/vault/core.go index aeca810f07..b8a9487b26 100644 --- a/vault/core.go +++ b/vault/core.go @@ -169,6 +169,9 @@ type InitResult struct { RootToken string } +// ReloadFunc are functions that are called when a reload is requested. +type ReloadFunc func(map[string]interface{}) error + // ErrInvalidKey is returned if there is an error with a // provided unseal key. type ErrInvalidKey struct { @@ -232,6 +235,9 @@ type Core struct { rekeyProgress [][]byte rekeyLock sync.Mutex + // reloadFuncs is the list of functions to call due to a reload request + reloadFuncs map[string]ReloadFunc + // mounts is loaded after unseal since it is a protected // configuration mounts *MountTable @@ -1646,3 +1652,11 @@ func (c *Core) emitMetrics(stopCh chan struct{}) { } } } + +// AddReloadFunc adds a reload func +func (c *Core) AddReloadFunc(name string, reloadFunc ReloadFunc) { + if c.reloadFuncs == nil { + c.reloadFuncs = map[string]ReloadFunc{} + } + c.reloadFuncs[name] = reloadFunc +} diff --git a/vault/logical_system.go b/vault/logical_system.go index f4bf4aac56..7da2ce86f4 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -5,6 +5,7 @@ import ( "strings" "time" + "github.com/hashicorp/go-multierror" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" "github.com/mitchellh/mapstructure" @@ -35,11 +36,30 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) logical.Backend "audit", "audit/*", "raw/*", + "reload", "rotate", }, }, Paths: []*framework.Path{ + &framework.Path{ + Pattern: "reload$", + + Fields: map[string]*framework.FieldSchema{ + "data": &framework.FieldSchema{ + Type: framework.TypeMap, + Default: map[string]interface{}{}, + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: b.handleReload, + }, + + HelpSynopsis: strings.TrimSpace(sysHelp["reload"][0]), + HelpDescription: strings.TrimSpace(sysHelp["reload"][1]), + }, + &framework.Path{ Pattern: "rekey/backup$", @@ -1130,6 +1150,26 @@ func (b *SystemBackend) handleRotate( return nil, nil } +// handleRreload is used to invoke configured reload functions +func (b *SystemBackend) handleReload( + req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + reloadData := data.Get("data").(map[string]interface{}) + + var respErr *multierror.Error + + for name, reloadFunc := range b.Core.reloadFuncs { + err := reloadFunc(reloadData) + if err != nil { + b.Core.logger.Printf("[ERR] %s reload function was unsuccessful: %s", name, err) + respErr = multierror.Append(respErr, err) + } else { + b.Core.logger.Printf("[INFO] %s reload function was successful", name) + } + } + + return nil, respErr.ErrorOrNil() +} + func sanitizeMountPath(path string) string { if !strings.HasSuffix(path, "/") { path += "/" @@ -1399,4 +1439,13 @@ Enable a new audit backend or disable an existing backend. "Allows fetching or deleting the backup of the rotated unseal keys.", "", }, + + "reload": { + "Allows reloading limited aspects of Vault's configuration dynamically.", + ` +Reload allows limited aspects of Vault's configuration to be reloaded +dynamically. Data given to the reload endpoint will be passed straight through +to the various reload functions inside Vault. +`, + }, }