diff --git a/command/server.go b/command/server.go index 19f97b9a74..13fc449cec 100644 --- a/command/server.go +++ b/command/server.go @@ -13,6 +13,7 @@ import ( "sort" "strconv" "strings" + "sync" "syscall" "time" @@ -43,6 +44,8 @@ type ServerCommand struct { ShutdownCh chan struct{} SighupCh chan struct{} + WaitGroup *sync.WaitGroup + meta.Meta logger *log.Logger @@ -308,31 +311,6 @@ func (c *ServerCommand) Run(args []string) int { } } - // If the backend supports service discovery, run service discovery - if coreConfig.HAPhysical != nil && coreConfig.HAPhysical.HAEnabled() { - sd, ok := coreConfig.HAPhysical.(physical.ServiceDiscovery) - if ok { - activeFunc := func() bool { - if isLeader, _, err := core.Leader(); err == nil { - return isLeader - } - return false - } - - sealedFunc := func() bool { - if sealed, err := core.Sealed(); err == nil { - return sealed - } - return true - } - - if err := sd.RunServiceDiscovery(c.ShutdownCh, coreConfig.AdvertiseAddr, activeFunc, sealedFunc); err != nil { - c.Ui.Error(fmt.Sprintf("Error initializing service discovery: %v", err)) - return 1 - } - } - } - // Initialize the listeners lns := make([]net.Listener, 0, len(config.Listeners)) for i, lnConfig := range config.Listeners { @@ -392,6 +370,37 @@ func (c *ServerCommand) Run(args []string) int { return 0 } + // Perform service discovery registrations and initialization of + // HTTP server after the verifyOnly check. + + // Instantiate the wait group + c.WaitGroup = &sync.WaitGroup{} + + // If the backend supports service discovery, run service discovery + if coreConfig.HAPhysical != nil && coreConfig.HAPhysical.HAEnabled() { + sd, ok := coreConfig.HAPhysical.(physical.ServiceDiscovery) + if ok { + activeFunc := func() bool { + if isLeader, _, err := core.Leader(); err == nil { + return isLeader + } + return false + } + + sealedFunc := func() bool { + if sealed, err := core.Sealed(); err == nil { + return sealed + } + return true + } + + if err := sd.RunServiceDiscovery(c.WaitGroup, c.ShutdownCh, coreConfig.AdvertiseAddr, activeFunc, sealedFunc); err != nil { + c.Ui.Error(fmt.Sprintf("Error initializing service discovery: %v", err)) + return 1 + } + } + } + // Initialize the HTTP server server := &http.Server{} server.Handler = vaulthttp.Handler(core) @@ -428,6 +437,7 @@ func (c *ServerCommand) Run(args []string) int { } } + c.WaitGroup.Wait() return 0 } diff --git a/physical/consul.go b/physical/consul.go index dc10e5741a..c379dd64de 100644 --- a/physical/consul.go +++ b/physical/consul.go @@ -416,17 +416,19 @@ func (c *ConsulBackend) checkDuration() time.Duration { return lib.DurationMinusBuffer(c.checkTimeout, checkMinBuffer, checkJitterFactor) } -func (c *ConsulBackend) RunServiceDiscovery(shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) (err error) { +func (c *ConsulBackend) RunServiceDiscovery(waitGroup *sync.WaitGroup, shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) (err error) { if err := c.setAdvertiseAddr(advertiseAddr); err != nil { return err } - go c.runEventDemuxer(shutdownCh, advertiseAddr, activeFunc, sealedFunc) + waitGroup.Add(1) + + go c.runEventDemuxer(waitGroup, shutdownCh, advertiseAddr, activeFunc, sealedFunc) return nil } -func (c *ConsulBackend) runEventDemuxer(shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) { +func (c *ConsulBackend) runEventDemuxer(waitGroup *sync.WaitGroup, shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) { // Fire the reconcileTimer immediately upon starting the event demuxer reconcileTimer := time.NewTimer(0) defer reconcileTimer.Stop() @@ -516,6 +518,7 @@ shutdown: if err := c.client.Agent().ServiceDeregister(registeredServiceID); err != nil { c.logger.Printf("[WARN]: physical/consul: service deregistration failed: %v", err) } + defer waitGroup.Done() } // checkID returns the ID used for a Consul Check. Assume at least a read diff --git a/physical/physical.go b/physical/physical.go index ff74c9827d..9e96beb6d8 100644 --- a/physical/physical.go +++ b/physical/physical.go @@ -3,6 +3,7 @@ package physical import ( "fmt" "log" + "sync" ) const DefaultParallelOperations = 128 @@ -71,7 +72,7 @@ type ServiceDiscovery interface { // Run executes any background service discovery tasks until the // shutdown channel is closed. - RunServiceDiscovery(shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) error + RunServiceDiscovery(waitGroup *sync.WaitGroup, shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) error } type Lock interface {