mirror of
https://github.com/hashicorp/vault.git
synced 2026-06-08 00:02:32 -04:00
Add reload capability for Vault listener certs. No tests (other than
manual) yet, and no documentation yet.
This commit is contained in:
parent
9e69eba911
commit
7e52796aae
5 changed files with 124 additions and 19 deletions
|
|
@ -274,7 +274,7 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
// Initialize the listeners
|
||||
lns := make([]net.Listener, 0, len(config.Listeners))
|
||||
for i, lnConfig := range config.Listeners {
|
||||
ln, props, err := server.NewListener(lnConfig.Type, lnConfig.Config)
|
||||
ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config)
|
||||
if err != nil {
|
||||
c.Ui.Error(fmt.Sprintf(
|
||||
"Error initializing listener of type %s: %s",
|
||||
|
|
@ -295,6 +295,8 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
"%s (%s)", lnConfig.Type, strings.Join(propsList, ", "))
|
||||
|
||||
lns = append(lns, ln)
|
||||
|
||||
core.AddReloadFunc("listener-"+ln.Addr().String(), reloadFunc)
|
||||
}
|
||||
|
||||
infoKeys = append(infoKeys, "version")
|
||||
|
|
|
|||
|
|
@ -8,10 +8,13 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
// ListenerFactory is the factory function to create a listener.
|
||||
type ListenerFactory func(map[string]string) (net.Listener, map[string]string, error)
|
||||
type ListenerFactory func(map[string]string) (net.Listener, map[string]string, vault.ReloadFunc, error)
|
||||
|
||||
// BuiltinListeners is the list of built-in listener types.
|
||||
var BuiltinListeners = map[string]ListenerFactory{
|
||||
|
|
@ -27,10 +30,10 @@ var tlsLookup = map[string]uint16{
|
|||
|
||||
// NewListener creates a new listener of the given type with the given
|
||||
// configuration. The type is looked up in the BuiltinListeners map.
|
||||
func NewListener(t string, config map[string]string) (net.Listener, map[string]string, error) {
|
||||
func NewListener(t string, config map[string]string) (net.Listener, map[string]string, vault.ReloadFunc, error) {
|
||||
f, ok := BuiltinListeners[t]
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("unknown listener type: %s", t)
|
||||
return nil, nil, nil, fmt.Errorf("unknown listener type: %s", t)
|
||||
}
|
||||
|
||||
return f(config)
|
||||
|
|
@ -39,32 +42,35 @@ func NewListener(t string, config map[string]string) (net.Listener, map[string]s
|
|||
func listenerWrapTLS(
|
||||
ln net.Listener,
|
||||
props map[string]string,
|
||||
config map[string]string) (net.Listener, map[string]string, error) {
|
||||
config map[string]string) (net.Listener, map[string]string, vault.ReloadFunc, error) {
|
||||
props["tls"] = "disabled"
|
||||
|
||||
if v, ok := config["tls_disable"]; ok {
|
||||
disabled, err := strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid value for 'tls_disable': %v", err)
|
||||
return nil, nil, nil, fmt.Errorf("invalid value for 'tls_disable': %v", err)
|
||||
}
|
||||
if disabled {
|
||||
return ln, props, nil
|
||||
return ln, props, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
certFile, ok := config["tls_cert_file"]
|
||||
_, ok := config["tls_cert_file"]
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("'tls_cert_file' must be set")
|
||||
return nil, nil, nil, fmt.Errorf("'tls_cert_file' must be set")
|
||||
}
|
||||
|
||||
keyFile, ok := config["tls_key_file"]
|
||||
_, ok = config["tls_key_file"]
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("'tls_key_file' must be set")
|
||||
return nil, nil, nil, fmt.Errorf("'tls_key_file' must be set")
|
||||
}
|
||||
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error loading TLS cert: %s", err)
|
||||
cg := &certificateGetter{
|
||||
config: config,
|
||||
}
|
||||
|
||||
if err := cg.reload(nil); err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error loading TLS cert: %s", err)
|
||||
}
|
||||
|
||||
tlsvers, ok := config["tls_min_version"]
|
||||
|
|
@ -73,15 +79,47 @@ func listenerWrapTLS(
|
|||
}
|
||||
|
||||
tlsConf := &tls.Config{}
|
||||
tlsConf.Certificates = []tls.Certificate{cert}
|
||||
tlsConf.GetCertificate = cg.getCertificate
|
||||
tlsConf.NextProtos = []string{"http/1.1"}
|
||||
tlsConf.MinVersion, ok = tlsLookup[tlsvers]
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("'tls_min_version' value %s not supported, please specify one of [tls10,tls11,tls12]", tlsvers)
|
||||
return nil, nil, nil, fmt.Errorf("'tls_min_version' value %s not supported, please specify one of [tls10,tls11,tls12]", tlsvers)
|
||||
}
|
||||
tlsConf.ClientAuth = tls.RequestClientCert
|
||||
|
||||
ln = tls.NewListener(ln, tlsConf)
|
||||
props["tls"] = "enabled"
|
||||
return ln, props, nil
|
||||
return ln, props, cg.reload, nil
|
||||
}
|
||||
|
||||
type certificateGetter struct {
|
||||
sync.RWMutex
|
||||
|
||||
config map[string]string
|
||||
cert *tls.Certificate
|
||||
}
|
||||
|
||||
func (cg *certificateGetter) reload(map[string]interface{}) error {
|
||||
cert, err := tls.LoadX509KeyPair(cg.config["tls_cert_file"], cg.config["tls_key_file"])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cg.Lock()
|
||||
defer cg.Unlock()
|
||||
|
||||
cg.cert = &cert
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cg *certificateGetter) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
cg.RLock()
|
||||
defer cg.RUnlock()
|
||||
|
||||
if cg.cert == nil {
|
||||
return nil, fmt.Errorf("nil certificate")
|
||||
}
|
||||
|
||||
return cg.cert, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,9 +3,11 @@ package server
|
|||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
func tcpListenerFactory(config map[string]string) (net.Listener, map[string]string, error) {
|
||||
func tcpListenerFactory(config map[string]string) (net.Listener, map[string]string, vault.ReloadFunc, error) {
|
||||
addr, ok := config["address"]
|
||||
if !ok {
|
||||
addr = "127.0.0.1:8200"
|
||||
|
|
@ -13,7 +15,7 @@ func tcpListenerFactory(config map[string]string) (net.Listener, map[string]stri
|
|||
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
ln = tcpKeepAliveListener{ln.(*net.TCPListener)}
|
||||
|
|
|
|||
|
|
@ -169,6 +169,9 @@ type InitResult struct {
|
|||
RootToken string
|
||||
}
|
||||
|
||||
// ReloadFunc are functions that are called when a reload is requested.
|
||||
type ReloadFunc func(map[string]interface{}) error
|
||||
|
||||
// ErrInvalidKey is returned if there is an error with a
|
||||
// provided unseal key.
|
||||
type ErrInvalidKey struct {
|
||||
|
|
@ -232,6 +235,9 @@ type Core struct {
|
|||
rekeyProgress [][]byte
|
||||
rekeyLock sync.Mutex
|
||||
|
||||
// reloadFuncs is the list of functions to call due to a reload request
|
||||
reloadFuncs map[string]ReloadFunc
|
||||
|
||||
// mounts is loaded after unseal since it is a protected
|
||||
// configuration
|
||||
mounts *MountTable
|
||||
|
|
@ -1646,3 +1652,11 @@ func (c *Core) emitMetrics(stopCh chan struct{}) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AddReloadFunc adds a reload func
|
||||
func (c *Core) AddReloadFunc(name string, reloadFunc ReloadFunc) {
|
||||
if c.reloadFuncs == nil {
|
||||
c.reloadFuncs = map[string]ReloadFunc{}
|
||||
}
|
||||
c.reloadFuncs[name] = reloadFunc
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
|
@ -35,11 +36,30 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) logical.Backend
|
|||
"audit",
|
||||
"audit/*",
|
||||
"raw/*",
|
||||
"reload",
|
||||
"rotate",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
&framework.Path{
|
||||
Pattern: "reload$",
|
||||
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"data": &framework.FieldSchema{
|
||||
Type: framework.TypeMap,
|
||||
Default: map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.UpdateOperation: b.handleReload,
|
||||
},
|
||||
|
||||
HelpSynopsis: strings.TrimSpace(sysHelp["reload"][0]),
|
||||
HelpDescription: strings.TrimSpace(sysHelp["reload"][1]),
|
||||
},
|
||||
|
||||
&framework.Path{
|
||||
Pattern: "rekey/backup$",
|
||||
|
||||
|
|
@ -1130,6 +1150,26 @@ func (b *SystemBackend) handleRotate(
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
// handleRreload is used to invoke configured reload functions
|
||||
func (b *SystemBackend) handleReload(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
reloadData := data.Get("data").(map[string]interface{})
|
||||
|
||||
var respErr *multierror.Error
|
||||
|
||||
for name, reloadFunc := range b.Core.reloadFuncs {
|
||||
err := reloadFunc(reloadData)
|
||||
if err != nil {
|
||||
b.Core.logger.Printf("[ERR] %s reload function was unsuccessful: %s", name, err)
|
||||
respErr = multierror.Append(respErr, err)
|
||||
} else {
|
||||
b.Core.logger.Printf("[INFO] %s reload function was successful", name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, respErr.ErrorOrNil()
|
||||
}
|
||||
|
||||
func sanitizeMountPath(path string) string {
|
||||
if !strings.HasSuffix(path, "/") {
|
||||
path += "/"
|
||||
|
|
@ -1399,4 +1439,13 @@ Enable a new audit backend or disable an existing backend.
|
|||
"Allows fetching or deleting the backup of the rotated unseal keys.",
|
||||
"",
|
||||
},
|
||||
|
||||
"reload": {
|
||||
"Allows reloading limited aspects of Vault's configuration dynamically.",
|
||||
`
|
||||
Reload allows limited aspects of Vault's configuration to be reloaded
|
||||
dynamically. Data given to the reload endpoint will be passed straight through
|
||||
to the various reload functions inside Vault.
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue