Add reload capability for Vault listener certs. No tests (other than

manual) yet, and no documentation yet.
This commit is contained in:
Jeff Mitchell 2016-03-09 21:40:46 -05:00
parent 9e69eba911
commit 7e52796aae
5 changed files with 124 additions and 19 deletions

View file

@ -274,7 +274,7 @@ func (c *ServerCommand) Run(args []string) int {
// Initialize the listeners
lns := make([]net.Listener, 0, len(config.Listeners))
for i, lnConfig := range config.Listeners {
ln, props, err := server.NewListener(lnConfig.Type, lnConfig.Config)
ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config)
if err != nil {
c.Ui.Error(fmt.Sprintf(
"Error initializing listener of type %s: %s",
@ -295,6 +295,8 @@ func (c *ServerCommand) Run(args []string) int {
"%s (%s)", lnConfig.Type, strings.Join(propsList, ", "))
lns = append(lns, ln)
core.AddReloadFunc("listener-"+ln.Addr().String(), reloadFunc)
}
infoKeys = append(infoKeys, "version")

View file

@ -8,10 +8,13 @@ import (
"fmt"
"net"
"strconv"
"sync"
"github.com/hashicorp/vault/vault"
)
// ListenerFactory is the factory function to create a listener.
type ListenerFactory func(map[string]string) (net.Listener, map[string]string, error)
type ListenerFactory func(map[string]string) (net.Listener, map[string]string, vault.ReloadFunc, error)
// BuiltinListeners is the list of built-in listener types.
var BuiltinListeners = map[string]ListenerFactory{
@ -27,10 +30,10 @@ var tlsLookup = map[string]uint16{
// NewListener creates a new listener of the given type with the given
// configuration. The type is looked up in the BuiltinListeners map.
func NewListener(t string, config map[string]string) (net.Listener, map[string]string, error) {
func NewListener(t string, config map[string]string) (net.Listener, map[string]string, vault.ReloadFunc, error) {
f, ok := BuiltinListeners[t]
if !ok {
return nil, nil, fmt.Errorf("unknown listener type: %s", t)
return nil, nil, nil, fmt.Errorf("unknown listener type: %s", t)
}
return f(config)
@ -39,32 +42,35 @@ func NewListener(t string, config map[string]string) (net.Listener, map[string]s
func listenerWrapTLS(
ln net.Listener,
props map[string]string,
config map[string]string) (net.Listener, map[string]string, error) {
config map[string]string) (net.Listener, map[string]string, vault.ReloadFunc, error) {
props["tls"] = "disabled"
if v, ok := config["tls_disable"]; ok {
disabled, err := strconv.ParseBool(v)
if err != nil {
return nil, nil, fmt.Errorf("invalid value for 'tls_disable': %v", err)
return nil, nil, nil, fmt.Errorf("invalid value for 'tls_disable': %v", err)
}
if disabled {
return ln, props, nil
return ln, props, nil, nil
}
}
certFile, ok := config["tls_cert_file"]
_, ok := config["tls_cert_file"]
if !ok {
return nil, nil, fmt.Errorf("'tls_cert_file' must be set")
return nil, nil, nil, fmt.Errorf("'tls_cert_file' must be set")
}
keyFile, ok := config["tls_key_file"]
_, ok = config["tls_key_file"]
if !ok {
return nil, nil, fmt.Errorf("'tls_key_file' must be set")
return nil, nil, nil, fmt.Errorf("'tls_key_file' must be set")
}
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, nil, fmt.Errorf("error loading TLS cert: %s", err)
cg := &certificateGetter{
config: config,
}
if err := cg.reload(nil); err != nil {
return nil, nil, nil, fmt.Errorf("error loading TLS cert: %s", err)
}
tlsvers, ok := config["tls_min_version"]
@ -73,15 +79,47 @@ func listenerWrapTLS(
}
tlsConf := &tls.Config{}
tlsConf.Certificates = []tls.Certificate{cert}
tlsConf.GetCertificate = cg.getCertificate
tlsConf.NextProtos = []string{"http/1.1"}
tlsConf.MinVersion, ok = tlsLookup[tlsvers]
if !ok {
return nil, nil, fmt.Errorf("'tls_min_version' value %s not supported, please specify one of [tls10,tls11,tls12]", tlsvers)
return nil, nil, nil, fmt.Errorf("'tls_min_version' value %s not supported, please specify one of [tls10,tls11,tls12]", tlsvers)
}
tlsConf.ClientAuth = tls.RequestClientCert
ln = tls.NewListener(ln, tlsConf)
props["tls"] = "enabled"
return ln, props, nil
return ln, props, cg.reload, nil
}
type certificateGetter struct {
sync.RWMutex
config map[string]string
cert *tls.Certificate
}
func (cg *certificateGetter) reload(map[string]interface{}) error {
cert, err := tls.LoadX509KeyPair(cg.config["tls_cert_file"], cg.config["tls_key_file"])
if err != nil {
return err
}
cg.Lock()
defer cg.Unlock()
cg.cert = &cert
return nil
}
func (cg *certificateGetter) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cg.RLock()
defer cg.RUnlock()
if cg.cert == nil {
return nil, fmt.Errorf("nil certificate")
}
return cg.cert, nil
}

View file

@ -3,9 +3,11 @@ package server
import (
"net"
"time"
"github.com/hashicorp/vault/vault"
)
func tcpListenerFactory(config map[string]string) (net.Listener, map[string]string, error) {
func tcpListenerFactory(config map[string]string) (net.Listener, map[string]string, vault.ReloadFunc, error) {
addr, ok := config["address"]
if !ok {
addr = "127.0.0.1:8200"
@ -13,7 +15,7 @@ func tcpListenerFactory(config map[string]string) (net.Listener, map[string]stri
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
ln = tcpKeepAliveListener{ln.(*net.TCPListener)}

View file

@ -169,6 +169,9 @@ type InitResult struct {
RootToken string
}
// ReloadFunc are functions that are called when a reload is requested.
type ReloadFunc func(map[string]interface{}) error
// ErrInvalidKey is returned if there is an error with a
// provided unseal key.
type ErrInvalidKey struct {
@ -232,6 +235,9 @@ type Core struct {
rekeyProgress [][]byte
rekeyLock sync.Mutex
// reloadFuncs is the list of functions to call due to a reload request
reloadFuncs map[string]ReloadFunc
// mounts is loaded after unseal since it is a protected
// configuration
mounts *MountTable
@ -1646,3 +1652,11 @@ func (c *Core) emitMetrics(stopCh chan struct{}) {
}
}
}
// AddReloadFunc adds a reload func
func (c *Core) AddReloadFunc(name string, reloadFunc ReloadFunc) {
if c.reloadFuncs == nil {
c.reloadFuncs = map[string]ReloadFunc{}
}
c.reloadFuncs[name] = reloadFunc
}

View file

@ -5,6 +5,7 @@ import (
"strings"
"time"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"github.com/mitchellh/mapstructure"
@ -35,11 +36,30 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) logical.Backend
"audit",
"audit/*",
"raw/*",
"reload",
"rotate",
},
},
Paths: []*framework.Path{
&framework.Path{
Pattern: "reload$",
Fields: map[string]*framework.FieldSchema{
"data": &framework.FieldSchema{
Type: framework.TypeMap,
Default: map[string]interface{}{},
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: b.handleReload,
},
HelpSynopsis: strings.TrimSpace(sysHelp["reload"][0]),
HelpDescription: strings.TrimSpace(sysHelp["reload"][1]),
},
&framework.Path{
Pattern: "rekey/backup$",
@ -1130,6 +1150,26 @@ func (b *SystemBackend) handleRotate(
return nil, nil
}
// handleRreload is used to invoke configured reload functions
func (b *SystemBackend) handleReload(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
reloadData := data.Get("data").(map[string]interface{})
var respErr *multierror.Error
for name, reloadFunc := range b.Core.reloadFuncs {
err := reloadFunc(reloadData)
if err != nil {
b.Core.logger.Printf("[ERR] %s reload function was unsuccessful: %s", name, err)
respErr = multierror.Append(respErr, err)
} else {
b.Core.logger.Printf("[INFO] %s reload function was successful", name)
}
}
return nil, respErr.ErrorOrNil()
}
func sanitizeMountPath(path string) string {
if !strings.HasSuffix(path, "/") {
path += "/"
@ -1399,4 +1439,13 @@ Enable a new audit backend or disable an existing backend.
"Allows fetching or deleting the backup of the rotated unseal keys.",
"",
},
"reload": {
"Allows reloading limited aspects of Vault's configuration dynamically.",
`
Reload allows limited aspects of Vault's configuration to be reloaded
dynamically. Data given to the reload endpoint will be passed straight through
to the various reload functions inside Vault.
`,
},
}