diff --git a/common/step_provision.go b/common/step_provision.go index 9c842044b..150a5c14e 100644 --- a/common/step_provision.go +++ b/common/step_provision.go @@ -22,7 +22,7 @@ type StepProvision struct { Comm packer.Communicator } -func (s *StepProvision) Run(_ context.Context, state multistep.StateBag) multistep.StepAction { +func (s *StepProvision) Run(ctx context.Context, state multistep.StateBag) multistep.StepAction { comm := s.Comm if comm == nil { raw, ok := state.Get("communicator").(packer.Communicator) @@ -38,7 +38,7 @@ func (s *StepProvision) Run(_ context.Context, state multistep.StateBag) multist log.Println("Running the provision hook") errCh := make(chan error, 1) go func() { - errCh <- hook.Run(packer.HookProvision, ui, comm, nil) + errCh <- hook.Run(ctx, packer.HookProvision, ui, comm, nil) }() for { @@ -53,7 +53,6 @@ func (s *StepProvision) Run(_ context.Context, state multistep.StateBag) multist case <-time.After(1 * time.Second): if _, ok := state.GetOk(multistep.StateCancelled); ok { log.Println("Cancelling provisioning due to interrupt...") - hook.Cancel() return multistep.ActionHalt } } diff --git a/packer/build_test.go b/packer/build_test.go index 057de0127..686f9ba40 100644 --- a/packer/build_test.go +++ b/packer/build_test.go @@ -192,7 +192,7 @@ func TestBuild_Run(t *testing.T) { // Verify hooks are dispatchable dispatchHook := builder.RunHook - dispatchHook.Run("foo", nil, nil, 42) + dispatchHook.Run(ctx, "foo", nil, nil, 42) hook := build.hooks["foo"][0].(*MockHook) if !hook.RunCalled { @@ -203,7 +203,7 @@ func TestBuild_Run(t *testing.T) { } // Verify provisioners run - dispatchHook.Run(HookProvision, nil, new(MockCommunicator), 42) + dispatchHook.Run(ctx, HookProvision, nil, new(MockCommunicator), 42) prov := build.provisioners[0].provisioner.(*MockProvisioner) if !prov.ProvCalled { t.Fatal("should be called") diff --git a/packer/builder.go b/packer/builder.go index 78aeae0bb..e24073128 100644 --- a/packer/builder.go +++ b/packer/builder.go @@ -1,5 +1,7 @@ package packer +import "context" + // Implementers of Builder are responsible for actually building images // on some platform given some configuration. // @@ -28,9 +30,5 @@ type Builder interface { Prepare(...interface{}) ([]string, error) // Run is where the actual build should take place. It takes a Build and a Ui. - Run(ui Ui, hook Hook) (Artifact, error) - - // Cancel cancels a possibly running Builder. This should block until - // the builder actually cancels and cleans up after itself. - Cancel() + Run(context.Context, Ui, Hook) (Artifact, error) } diff --git a/packer/builder_mock.go b/packer/builder_mock.go index fc2fd19db..8ae5bdbbe 100644 --- a/packer/builder_mock.go +++ b/packer/builder_mock.go @@ -1,6 +1,7 @@ package packer import ( + "context" "errors" ) @@ -27,7 +28,7 @@ func (tb *MockBuilder) Prepare(config ...interface{}) ([]string, error) { return tb.PrepareWarnings, nil } -func (tb *MockBuilder) Run(ui Ui, h Hook) (Artifact, error) { +func (tb *MockBuilder) Run(ctx context.Context, ui Ui, h Hook) (Artifact, error) { tb.RunCalled = true tb.RunHook = h tb.RunUi = ui @@ -41,7 +42,7 @@ func (tb *MockBuilder) Run(ui Ui, h Hook) (Artifact, error) { } if h != nil { - if err := h.Run(HookProvision, ui, new(MockCommunicator), nil); err != nil { + if err := h.Run(ctx, HookProvision, ui, new(MockCommunicator), nil); err != nil { return nil, err } } @@ -50,7 +51,3 @@ func (tb *MockBuilder) Run(ui Ui, h Hook) (Artifact, error) { IdValue: tb.ArtifactId, }, nil } - -func (tb *MockBuilder) Cancel() { - tb.CancelCalled = true -} diff --git a/packer/hook.go b/packer/hook.go index e5e7ad8a9..5c4d308b1 100644 --- a/packer/hook.go +++ b/packer/hook.go @@ -1,7 +1,7 @@ package packer import ( - "sync" + "context" ) // This is the hook that should be fired for provisioners to run. @@ -21,33 +21,18 @@ const HookProvision = "packer_provision" // must be race-free. Cancel should attempt to cancel the hook in the // quickest, safest way possible. type Hook interface { - Run(string, Ui, Communicator, interface{}) error - Cancel() + Run(context.Context, string, Ui, Communicator, interface{}) error } // A Hook implementation that dispatches based on an internal mapping. type DispatchHook struct { Mapping map[string][]Hook - - l sync.Mutex - cancelled bool - runningHook Hook } // Runs the hook with the given name by dispatching it to the proper // hooks if a mapping exists. If a mapping doesn't exist, then nothing // happens. -func (h *DispatchHook) Run(name string, ui Ui, comm Communicator, data interface{}) error { - h.l.Lock() - h.cancelled = false - h.l.Unlock() - - // Make sure when we exit that we reset the running hook. - defer func() { - h.l.Lock() - defer h.l.Unlock() - h.runningHook = nil - }() +func (h *DispatchHook) Run(ctx context.Context, name string, ui Ui, comm Communicator, data interface{}) error { hooks, ok := h.Mapping[name] if !ok { @@ -56,32 +41,14 @@ func (h *DispatchHook) Run(name string, ui Ui, comm Communicator, data interface } for _, hook := range hooks { - h.l.Lock() - if h.cancelled { - h.l.Unlock() - return nil + if err := ctx.Err(); err != nil { + return err } - h.runningHook = hook - h.l.Unlock() - - if err := hook.Run(name, ui, comm, data); err != nil { + if err := hook.Run(ctx, name, ui, comm, data); err != nil { return err } } return nil } - -// Cancels all the hooks that are currently in-flight, if any. This will -// block until the hooks are all cancelled. -func (h *DispatchHook) Cancel() { - h.l.Lock() - defer h.l.Unlock() - - if h.runningHook != nil { - h.runningHook.Cancel() - } - - h.cancelled = true -} diff --git a/packer/hook_mock.go b/packer/hook_mock.go index 7177329e3..ea51c7635 100644 --- a/packer/hook_mock.go +++ b/packer/hook_mock.go @@ -1,5 +1,10 @@ package packer +import ( + "context" + "time" +) + // MockHook is an implementation of Hook that can be used for tests. type MockHook struct { RunFunc func() error @@ -12,7 +17,16 @@ type MockHook struct { CancelCalled bool } -func (t *MockHook) Run(name string, ui Ui, comm Communicator, data interface{}) error { +func (t *MockHook) Run(ctx context.Context, name string, ui Ui, comm Communicator, data interface{}) error { + + go func() { + select { + case <-time.After(2 * time.Minute): + case <-ctx.Done(): + t.CancelCalled = true + } + }() + t.RunCalled = true t.RunComm = comm t.RunData = data @@ -25,7 +39,3 @@ func (t *MockHook) Run(name string, ui Ui, comm Communicator, data interface{}) return t.RunFunc() } - -func (t *MockHook) Cancel() { - t.CancelCalled = true -} diff --git a/packer/hook_test.go b/packer/hook_test.go index 3830cd26a..633220289 100644 --- a/packer/hook_test.go +++ b/packer/hook_test.go @@ -1,6 +1,7 @@ package packer import ( + "context" "sync" "testing" "time" @@ -15,7 +16,15 @@ type CancelHook struct { Cancelled bool } -func (h *CancelHook) Run(string, Ui, Communicator, interface{}) error { +func (h *CancelHook) Run(ctx context.Context, _ string, _ Ui, _ Communicator, _ interface{}) error { + go func() { + select { + case <-time.After(2 * time.Minute): + case <-ctx.Done(): + h.cancel() + } + }() + h.Lock() h.cancelCh = make(chan struct{}) h.doneCh = make(chan struct{}) @@ -32,7 +41,7 @@ func (h *CancelHook) Run(string, Ui, Communicator, interface{}) error { return nil } -func (h *CancelHook) Cancel() { +func (h *CancelHook) cancel() { h.Lock() close(h.cancelCh) h.Unlock() @@ -47,7 +56,7 @@ func TestDispatchHook_Implements(t *testing.T) { func TestDispatchHook_Run_NoHooks(t *testing.T) { // Just make sure nothing blows up dh := &DispatchHook{} - dh.Run("foo", nil, nil, nil) + dh.Run(context.Background(), "foo", nil, nil, nil) } func TestDispatchHook_Run(t *testing.T) { @@ -56,7 +65,7 @@ func TestDispatchHook_Run(t *testing.T) { mapping := make(map[string][]Hook) mapping["foo"] = []Hook{hook} dh := &DispatchHook{Mapping: mapping} - dh.Run("foo", nil, nil, 42) + dh.Run(context.Background(), "foo", nil, nil, 42) if !hook.RunCalled { t.Fatal("should be called") @@ -77,10 +86,10 @@ func TestDispatchHook_cancel(t *testing.T) { "foo": {hook}, }, } - - go dh.Run("foo", nil, nil, 42) + ctx, cancel := context.WithCancel(context.Background()) + go dh.Run(ctx, "foo", nil, nil, 42) time.Sleep(100 * time.Millisecond) - dh.Cancel() + cancel() if !hook.Cancelled { t.Fatal("hook should've cancelled") diff --git a/packer/plugin/hook.go b/packer/plugin/hook.go index ca28ecbee..e983e00af 100644 --- a/packer/plugin/hook.go +++ b/packer/plugin/hook.go @@ -1,6 +1,7 @@ package plugin import ( + "context" "log" "github.com/hashicorp/packer/packer" @@ -11,22 +12,13 @@ type cmdHook struct { client *Client } -func (c *cmdHook) Run(name string, ui packer.Ui, comm packer.Communicator, data interface{}) error { +func (c *cmdHook) Run(ctx context.Context, name string, ui packer.Ui, comm packer.Communicator, data interface{}) error { defer func() { r := recover() c.checkExit(r, nil) }() - return c.hook.Run(name, ui, comm, data) -} - -func (c *cmdHook) Cancel() { - defer func() { - r := recover() - c.checkExit(r, nil) - }() - - c.hook.Cancel() + return c.hook.Run(ctx, name, ui, comm, data) } func (c *cmdHook) checkExit(p interface{}, cb func()) { diff --git a/packer/provisioner_mock.go b/packer/provisioner_mock.go index 743a8e54a..01bee4c3c 100644 --- a/packer/provisioner_mock.go +++ b/packer/provisioner_mock.go @@ -1,6 +1,9 @@ package packer -import "context" +import ( + "context" + "time" +) // MockProvisioner is an implementation of Provisioner that can be // used for tests. @@ -22,6 +25,14 @@ func (t *MockProvisioner) Prepare(configs ...interface{}) error { } func (t *MockProvisioner) Provision(ctx context.Context, ui Ui, comm Communicator) error { + go func() { + select { + case <-time.After(2 * time.Minute): + case <-ctx.Done(): + t.CancelCalled = true + } + }() + t.ProvCalled = true t.ProvCommunicator = comm t.ProvUi = ui @@ -33,10 +44,6 @@ func (t *MockProvisioner) Provision(ctx context.Context, ui Ui, comm Communicato return t.ProvFunc() } -func (t *MockProvisioner) Cancel() { - t.CancelCalled = true -} - func (t *MockProvisioner) Communicator() Communicator { return t.ProvCommunicator } diff --git a/packer/provisioner_test.go b/packer/provisioner_test.go index 9d386f41a..224d201d3 100644 --- a/packer/provisioner_test.go +++ b/packer/provisioner_test.go @@ -30,7 +30,7 @@ func TestProvisionHook(t *testing.T) { }, } - hook.Run("foo", ui, comm, data) + hook.Run(context.Background(), "foo", ui, comm, data) if !pA.ProvCalled { t.Error("provision should be called on pA") @@ -56,7 +56,7 @@ func TestProvisionHook_nilComm(t *testing.T) { }, } - err := hook.Run("foo", ui, comm, data) + err := hook.Run(context.Background(), "foo", ui, comm, data) if err == nil { t.Fatal("should error") } @@ -83,16 +83,17 @@ func TestProvisionHook_cancel(t *testing.T) { {p, nil, ""}, }, } + ctx, cancel := context.WithCancel(context.Background()) finished := make(chan struct{}) go func() { - hook.Run("foo", nil, new(MockCommunicator), nil) + hook.Run(ctx, "foo", nil, new(MockCommunicator), nil) close(finished) }() // Cancel it while it is running time.Sleep(10 * time.Millisecond) - hook.Cancel() + cancel() lock.Lock() order = append(order, "cancel") lock.Unlock() @@ -187,13 +188,15 @@ func TestPausedProvisionerCancel(t *testing.T) { time.Sleep(10 * time.Millisecond) return nil } + ctx, cancel := context.WithCancel(context.Background()) // Start provisioning and wait for it to start - go prov.Provision(context.Background(), testUi(), new(MockCommunicator)) - <-provCh + go func() { + <-provCh + cancel() + }() - // Cancel it - prov.Cancel() + prov.Provision(ctx, testUi(), new(MockCommunicator)) if !mock.CancelCalled { t.Fatal("cancel should be called") } @@ -251,13 +254,15 @@ func TestDebuggedProvisionerCancel(t *testing.T) { time.Sleep(10 * time.Millisecond) return nil } + ctx, cancel := context.WithCancel(context.Background()) // Start provisioning and wait for it to start - go prov.Provision(context.Background(), testUi(), new(MockCommunicator)) - <-provCh + go func() { + <-provCh + cancel() + }() - // Cancel it - prov.Cancel() + prov.Provision(ctx, testUi(), new(MockCommunicator)) if !mock.CancelCalled { t.Fatal("cancel should be called") } diff --git a/packer/rpc/hook.go b/packer/rpc/hook.go index 623f86ce7..1ca368e45 100644 --- a/packer/rpc/hook.go +++ b/packer/rpc/hook.go @@ -1,6 +1,7 @@ package rpc import ( + "context" "log" "net/rpc" @@ -27,7 +28,7 @@ type HookRunArgs struct { StreamId uint32 } -func (h *hook) Run(name string, ui packer.Ui, comm packer.Communicator, data interface{}) error { +func (h *hook) Run(ctx context.Context, name string, ui packer.Ui, comm packer.Communicator, data interface{}) error { nextId := h.mux.NextId() server := newServerWithMux(h.mux, nextId) server.RegisterCommunicator(comm) @@ -50,22 +51,17 @@ func (h *hook) Cancel() { } } -func (h *HookServer) Run(args *HookRunArgs, reply *interface{}) error { +func (h *HookServer) Run(ctx context.Context, args *HookRunArgs, reply *interface{}) error { client, err := newClientWithMux(h.mux, args.StreamId) if err != nil { return NewBasicError(err) } defer client.Close() - if err := h.hook.Run(args.Name, client.Ui(), client.Communicator(), args.Data); err != nil { + if err := h.hook.Run(ctx, args.Name, client.Ui(), client.Communicator(), args.Data); err != nil { return NewBasicError(err) } *reply = nil return nil } - -func (h *HookServer) Cancel(args *interface{}, reply *interface{}) error { - h.hook.Cancel() - return nil -} diff --git a/packer/rpc/hook_test.go b/packer/rpc/hook_test.go index d1ae7ea06..f4a0b8169 100644 --- a/packer/rpc/hook_test.go +++ b/packer/rpc/hook_test.go @@ -1,6 +1,7 @@ package rpc import ( + "context" "reflect" "sync" "testing" @@ -12,6 +13,7 @@ import ( func TestHookRPC(t *testing.T) { // Create the UI to test h := new(packer.MockHook) + ctx, cancel := context.WithCancel(context.Background()) // Serve client, server := testClientServer(t) @@ -22,13 +24,13 @@ func TestHookRPC(t *testing.T) { // Test Run ui := &testUi{} - hClient.Run("foo", ui, nil, 42) + hClient.Run(ctx, "foo", ui, nil, 42) if !h.RunCalled { t.Fatal("should be called") } // Test Cancel - hClient.Cancel() + cancel() if !h.CancelCalled { t.Fatal("should be called") } @@ -52,6 +54,7 @@ func TestHook_cancelWhileRun(t *testing.T) { return nil }, } + ctx, cancel := context.WithCancel(context.Background()) // Serve client, server := testClientServer(t) @@ -63,13 +66,13 @@ func TestHook_cancelWhileRun(t *testing.T) { // Start the run finished := make(chan struct{}) go func() { - hClient.Run("foo", nil, nil, nil) + hClient.Run(ctx, "foo", nil, nil, nil) close(finished) }() // Cancel it pretty quickly. time.Sleep(10 * time.Millisecond) - hClient.Cancel() + cancel() finishLock.Lock() finishOrder = append(finishOrder, "cancel")