From 14f538556eeccc852b93f39cdd34066737e32fae Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Fri, 11 Mar 2016 17:28:03 -0500 Subject: [PATCH] Don't generate an ID; use address for the ID. Generally speaking we'll need to sane against what's in the config --- command/server.go | 15 +++++++-------- command/server/config.go | 6 +----- command/server/listener.go | 19 ++++++++----------- command/server/listener_tcp.go | 2 +- 4 files changed, 17 insertions(+), 25 deletions(-) diff --git a/command/server.go b/command/server.go index f0fc1750b5..12c604a12e 100644 --- a/command/server.go +++ b/command/server.go @@ -40,7 +40,7 @@ type ServerCommand struct { Meta - ReloadFuncs map[string]server.ReloadFunc + ReloadFuncs []server.ReloadFunc } func (c *ServerCommand) Run(args []string) int { @@ -279,7 +279,7 @@ func (c *ServerCommand) Run(args []string) int { // Initialize the listeners lns := make([]net.Listener, 0, len(config.Listeners)) for i, lnConfig := range config.Listeners { - ln, props, reloadFactory, err := server.NewListener(lnConfig.Type, lnConfig.Config) + ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config) if err != nil { c.Ui.Error(fmt.Sprintf( "Error initializing listener of type %s: %s", @@ -301,9 +301,8 @@ func (c *ServerCommand) Run(args []string) int { lns = append(lns, ln) - if reloadFactory != nil { - relId, relFunc := reloadFactory() - c.ReloadFuncs[relId] = relFunc + if reloadFunc != nil { + c.ReloadFuncs = append(c.ReloadFuncs, reloadFunc) } } @@ -578,9 +577,9 @@ func (c *ServerCommand) Reload(configPath []string) error { // Call reload on the listeners. This will call each listener with each // config block, but they verify the ID. for _, lnConfig := range config.Listeners { - for id, relFunc := range c.ReloadFuncs { - if err := relFunc(id, lnConfig.Config); err != nil { - retErr := fmt.Errorf("Error encountered reloading configuration for %s: %s", id, err) + for _, relFunc := range c.ReloadFuncs { + if err := relFunc(lnConfig.Config); err != nil { + retErr := fmt.Errorf("Error encountered reloading configuration: %s", err) reloadErrors = multierror.Append(retErr) } } diff --git a/command/server/config.go b/command/server/config.go index 340a237880..20e4434b37 100644 --- a/command/server/config.go +++ b/command/server/config.go @@ -16,11 +16,7 @@ import ( ) // ReloadFunc are functions that are called when a reload is requested. -type ReloadFunc func(string, map[string]string) error - -// ReloadFactory can be called to return the desired ID and the associated -// reload function. -type ReloadFactory func() (string, ReloadFunc) +type ReloadFunc func(map[string]string) error // Config is the configuration for the vault server. type Config struct { diff --git a/command/server/listener.go b/command/server/listener.go index 8f1ac90f74..bbe9d18bb6 100644 --- a/command/server/listener.go +++ b/command/server/listener.go @@ -12,7 +12,7 @@ import ( ) // ListenerFactory is the factory function to create a listener. -type ListenerFactory func(map[string]string) (net.Listener, map[string]string, ReloadFactory, error) +type ListenerFactory func(map[string]string) (net.Listener, map[string]string, ReloadFunc, error) // BuiltinListeners is the list of built-in listener types. var BuiltinListeners = map[string]ListenerFactory{ @@ -28,7 +28,7 @@ var tlsLookup = map[string]uint16{ // NewListener creates a new listener of the given type with the given // configuration. The type is looked up in the BuiltinListeners map. -func NewListener(t string, config map[string]string) (net.Listener, map[string]string, ReloadFactory, error) { +func NewListener(t string, config map[string]string) (net.Listener, map[string]string, ReloadFunc, error) { f, ok := BuiltinListeners[t] if !ok { return nil, nil, nil, fmt.Errorf("unknown listener type: %s", t) @@ -40,7 +40,7 @@ func NewListener(t string, config map[string]string) (net.Listener, map[string]s func listenerWrapTLS( ln net.Listener, props map[string]string, - config map[string]string) (net.Listener, map[string]string, ReloadFactory, error) { + config map[string]string) (net.Listener, map[string]string, ReloadFunc, error) { props["tls"] = "disabled" if v, ok := config["tls_disable"]; ok { @@ -64,10 +64,10 @@ func listenerWrapTLS( } cg := &certificateGetter{ - id: "listen|" + ln.Addr().String(), + id: config["address"], } - if err := cg.reload(cg.id, config); err != nil { + if err := cg.reload(config); err != nil { return nil, nil, nil, fmt.Errorf("error loading TLS cert: %s", err) } @@ -87,10 +87,7 @@ func listenerWrapTLS( ln = tls.NewListener(ln, tlsConf) props["tls"] = "enabled" - reloadFac := func() (string, ReloadFunc) { - return cg.id, cg.reload - } - return ln, props, reloadFac, nil + return ln, props, cg.reload, nil } type certificateGetter struct { @@ -101,8 +98,8 @@ type certificateGetter struct { id string } -func (cg *certificateGetter) reload(id string, config map[string]string) error { - if id != cg.id { +func (cg *certificateGetter) reload(config map[string]string) error { + if config["address"] != cg.id { return nil } diff --git a/command/server/listener_tcp.go b/command/server/listener_tcp.go index c68a29fd85..d4ba3aaff0 100644 --- a/command/server/listener_tcp.go +++ b/command/server/listener_tcp.go @@ -5,7 +5,7 @@ import ( "time" ) -func tcpListenerFactory(config map[string]string) (net.Listener, map[string]string, ReloadFactory, error) { +func tcpListenerFactory(config map[string]string) (net.Listener, map[string]string, ReloadFunc, error) { addr, ok := config["address"] if !ok { addr = "127.0.0.1:8200"