Don't generate an ID; use address for the ID. Generally speaking we'll need to sane against what's in the config

This commit is contained in:
Jeff Mitchell 2016-03-11 17:28:03 -05:00
parent ca40e06f5d
commit 14f538556e
4 changed files with 17 additions and 25 deletions

View file

@ -40,7 +40,7 @@ type ServerCommand struct {
Meta
ReloadFuncs map[string]server.ReloadFunc
ReloadFuncs []server.ReloadFunc
}
func (c *ServerCommand) Run(args []string) int {
@ -279,7 +279,7 @@ func (c *ServerCommand) Run(args []string) int {
// Initialize the listeners
lns := make([]net.Listener, 0, len(config.Listeners))
for i, lnConfig := range config.Listeners {
ln, props, reloadFactory, err := server.NewListener(lnConfig.Type, lnConfig.Config)
ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config)
if err != nil {
c.Ui.Error(fmt.Sprintf(
"Error initializing listener of type %s: %s",
@ -301,9 +301,8 @@ func (c *ServerCommand) Run(args []string) int {
lns = append(lns, ln)
if reloadFactory != nil {
relId, relFunc := reloadFactory()
c.ReloadFuncs[relId] = relFunc
if reloadFunc != nil {
c.ReloadFuncs = append(c.ReloadFuncs, reloadFunc)
}
}
@ -578,9 +577,9 @@ func (c *ServerCommand) Reload(configPath []string) error {
// Call reload on the listeners. This will call each listener with each
// config block, but they verify the ID.
for _, lnConfig := range config.Listeners {
for id, relFunc := range c.ReloadFuncs {
if err := relFunc(id, lnConfig.Config); err != nil {
retErr := fmt.Errorf("Error encountered reloading configuration for %s: %s", id, err)
for _, relFunc := range c.ReloadFuncs {
if err := relFunc(lnConfig.Config); err != nil {
retErr := fmt.Errorf("Error encountered reloading configuration: %s", err)
reloadErrors = multierror.Append(retErr)
}
}

View file

@ -16,11 +16,7 @@ import (
)
// ReloadFunc are functions that are called when a reload is requested.
type ReloadFunc func(string, map[string]string) error
// ReloadFactory can be called to return the desired ID and the associated
// reload function.
type ReloadFactory func() (string, ReloadFunc)
type ReloadFunc func(map[string]string) error
// Config is the configuration for the vault server.
type Config struct {

View file

@ -12,7 +12,7 @@ import (
)
// ListenerFactory is the factory function to create a listener.
type ListenerFactory func(map[string]string) (net.Listener, map[string]string, ReloadFactory, error)
type ListenerFactory func(map[string]string) (net.Listener, map[string]string, ReloadFunc, error)
// BuiltinListeners is the list of built-in listener types.
var BuiltinListeners = map[string]ListenerFactory{
@ -28,7 +28,7 @@ var tlsLookup = map[string]uint16{
// NewListener creates a new listener of the given type with the given
// configuration. The type is looked up in the BuiltinListeners map.
func NewListener(t string, config map[string]string) (net.Listener, map[string]string, ReloadFactory, error) {
func NewListener(t string, config map[string]string) (net.Listener, map[string]string, ReloadFunc, error) {
f, ok := BuiltinListeners[t]
if !ok {
return nil, nil, nil, fmt.Errorf("unknown listener type: %s", t)
@ -40,7 +40,7 @@ func NewListener(t string, config map[string]string) (net.Listener, map[string]s
func listenerWrapTLS(
ln net.Listener,
props map[string]string,
config map[string]string) (net.Listener, map[string]string, ReloadFactory, error) {
config map[string]string) (net.Listener, map[string]string, ReloadFunc, error) {
props["tls"] = "disabled"
if v, ok := config["tls_disable"]; ok {
@ -64,10 +64,10 @@ func listenerWrapTLS(
}
cg := &certificateGetter{
id: "listen|" + ln.Addr().String(),
id: config["address"],
}
if err := cg.reload(cg.id, config); err != nil {
if err := cg.reload(config); err != nil {
return nil, nil, nil, fmt.Errorf("error loading TLS cert: %s", err)
}
@ -87,10 +87,7 @@ func listenerWrapTLS(
ln = tls.NewListener(ln, tlsConf)
props["tls"] = "enabled"
reloadFac := func() (string, ReloadFunc) {
return cg.id, cg.reload
}
return ln, props, reloadFac, nil
return ln, props, cg.reload, nil
}
type certificateGetter struct {
@ -101,8 +98,8 @@ type certificateGetter struct {
id string
}
func (cg *certificateGetter) reload(id string, config map[string]string) error {
if id != cg.id {
func (cg *certificateGetter) reload(config map[string]string) error {
if config["address"] != cg.id {
return nil
}

View file

@ -5,7 +5,7 @@ import (
"time"
)
func tcpListenerFactory(config map[string]string) (net.Listener, map[string]string, ReloadFactory, error) {
func tcpListenerFactory(config map[string]string) (net.Listener, map[string]string, ReloadFunc, error) {
addr, ok := config["address"]
if !ok {
addr = "127.0.0.1:8200"