diff --git a/common/step_provision.go b/common/step_provision.go index 150a5c14e..dfa756b72 100644 --- a/common/step_provision.go +++ b/common/step_provision.go @@ -50,6 +50,9 @@ func (s *StepProvision) Run(ctx context.Context, state multistep.StateBag) multi } return multistep.ActionContinue + case <-ctx.Done(): + log.Printf("Cancelling provisioning due to context cancellation: %s", ctx.Err()) + return multistep.ActionHalt case <-time.After(1 * time.Second): if _, ok := state.GetOk(multistep.StateCancelled); ok { log.Println("Cancelling provisioning due to interrupt...") diff --git a/helper/multistep/basic_runner.go b/helper/multistep/basic_runner.go index 465c730c3..6154d07e4 100644 --- a/helper/multistep/basic_runner.go +++ b/helper/multistep/basic_runner.go @@ -20,14 +20,11 @@ type BasicRunner struct { // modified. Steps []Step - cancel context.CancelFunc - doneCh chan struct{} - state runState - l sync.Mutex + l sync.Mutex + state runState } func (b *BasicRunner) Run(ctx context.Context, state StateBag) { - ctx, cancel := context.WithCancel(ctx) b.l.Lock() if b.state != stateIdle { @@ -35,15 +32,11 @@ func (b *BasicRunner) Run(ctx context.Context, state StateBag) { } doneCh := make(chan struct{}) - b.cancel = cancel - b.doneCh = doneCh b.state = stateRunning b.l.Unlock() defer func() { b.l.Lock() - b.cancel = nil - b.doneCh = nil b.state = stateIdle close(doneCh) b.l.Unlock() @@ -54,14 +47,16 @@ func (b *BasicRunner) Run(ctx context.Context, state StateBag) { go func() { select { case <-ctx.Done(): - // Flag cancel and wait for finish state.Put(StateCancelled, true) - <-doneCh case <-doneCh: } }() for _, step := range b.Steps { + if err := ctx.Err(); err != nil { + state.Put(StateCancelled, true) + break + } // We also check for cancellation here since we can't be sure // the goroutine that is running to set it actually ran. if runState(atomic.LoadInt32((*int32)(&b.state))) == stateCancelling { diff --git a/helper/multistep/basic_runner_test.go b/helper/multistep/basic_runner_test.go index e36fd3f92..51fbc54a9 100644 --- a/helper/multistep/basic_runner_test.go +++ b/helper/multistep/basic_runner_test.go @@ -4,7 +4,6 @@ import ( "context" "reflect" "testing" - "time" ) func TestBasicRunner_ImplRunner(t *testing.T) { @@ -99,39 +98,47 @@ func TestBasicRunner_Run_Run(t *testing.T) { } func TestBasicRunner_Cancel(t *testing.T) { - ch := make(chan chan bool) - data := new(BasicStateBag) - stepA := &TestStepAcc{Data: "a"} - stepB := &TestStepAcc{Data: "b"} - stepInt := &TestStepSync{ch} - stepC := &TestStepAcc{Data: "c"} - r := &BasicRunner{Steps: []Step{stepA, stepB, stepInt, stepC}} + topCtx, topCtxCancel := context.WithCancel(context.Background()) - ctx, cancel := context.WithCancel(context.Background()) - - go r.Run(ctx, data) - - // Wait until we reach the sync point - responseCh := <-ch - - // Cancel then continue chain - cancelCh := make(chan bool) - go func() { - cancel() - cancelCh <- true - }() - - for { - if _, ok := data.GetOk(StateCancelled); ok { - responseCh <- true - break + checkCancelled := func(data StateBag) { + cancelled := data.Get(StateCancelled).(bool) + if !cancelled { + t.Fatal("state should be cancelled") } - - time.Sleep(10 * time.Millisecond) } - <-cancelCh + data := new(BasicStateBag) + r := &BasicRunner{} + r.Steps = []Step{ + &TestStepAcc{Data: "a"}, + &TestStepAcc{Data: "b"}, + TestStepFn{ + run: func(ctx context.Context, sb StateBag) StepAction { + return ActionContinue + }, + cleanup: checkCancelled, + }, + TestStepFn{ + run: func(ctx context.Context, sb StateBag) StepAction { + topCtxCancel() + <-ctx.Done() + return ActionContinue + }, + cleanup: checkCancelled, + }, + TestStepFn{ + run: func(context.Context, StateBag) StepAction { + t.Fatal("I should not be called") + return ActionContinue + }, + cleanup: func(StateBag) { + t.Fatal("I should not be called") + }, + }, + } + + r.Run(topCtx, data) // Test run data expected := []string{"a", "b"} @@ -148,10 +155,8 @@ func TestBasicRunner_Cancel(t *testing.T) { } // Test that it says it is cancelled - cancelled := data.Get(StateCancelled).(bool) - if !cancelled { - t.Errorf("not cancelled") - } + checkCancelled(data) + } func TestBasicRunner_Cancel_Special(t *testing.T) { diff --git a/helper/multistep/debug_runner_test.go b/helper/multistep/debug_runner_test.go index f26123bdc..5cd744139 100644 --- a/helper/multistep/debug_runner_test.go +++ b/helper/multistep/debug_runner_test.go @@ -78,41 +78,64 @@ func TestDebugRunner_Run_Run(t *testing.T) { t.Errorf("Was able to run an already running DebugRunner") } +type TestStepFn struct { + run func(context.Context, StateBag) StepAction + cleanup func(StateBag) +} + +var _ Step = TestStepFn{} + +func (fn TestStepFn) Run(ctx context.Context, sb StateBag) StepAction { + return fn.run(ctx, sb) +} + +func (fn TestStepFn) Cleanup(sb StateBag) { + if fn.cleanup != nil { + fn.cleanup(sb) + } +} func TestDebugRunner_Cancel(t *testing.T) { - ch := make(chan chan bool) - data := new(BasicStateBag) - stepA := &TestStepAcc{Data: "a"} - stepB := &TestStepAcc{Data: "b"} - stepInt := &TestStepSync{ch} - stepC := &TestStepAcc{Data: "c"} - r := &DebugRunner{} - r.Steps = []Step{stepA, stepB, stepInt, stepC} + topCtx, topCtxCancel := context.WithCancel(context.Background()) - ctx, cancel := context.WithCancel(context.Background()) - - go r.Run(ctx, data) - - // Wait until we reach the sync point - responseCh := <-ch - - // Cancel then continue chain - cancelCh := make(chan bool) - go func() { - cancel() - cancelCh <- true - }() - - for { - if _, ok := data.GetOk(StateCancelled); ok { - responseCh <- true - break + checkCancelled := func(data StateBag) { + cancelled := data.Get(StateCancelled).(bool) + if !cancelled { + t.Fatal("state should be cancelled") } - - time.Sleep(10 * time.Millisecond) } - <-cancelCh + data := new(BasicStateBag) + r := &DebugRunner{} + r.Steps = []Step{ + &TestStepAcc{Data: "a"}, + &TestStepAcc{Data: "b"}, + TestStepFn{ + run: func(ctx context.Context, sb StateBag) StepAction { + return ActionContinue + }, + cleanup: checkCancelled, + }, + TestStepFn{ + run: func(ctx context.Context, sb StateBag) StepAction { + topCtxCancel() + <-ctx.Done() + return ActionContinue + }, + cleanup: checkCancelled, + }, + TestStepFn{ + run: func(context.Context, StateBag) StepAction { + t.Fatal("I should not be called") + return ActionContinue + }, + cleanup: func(StateBag) { + t.Fatal("I should not be called") + }, + }, + } + + r.Run(topCtx, data) // Test run data expected := []string{"a", "b"} @@ -129,8 +152,11 @@ func TestDebugRunner_Cancel(t *testing.T) { } // Test that it says it is cancelled - cancelled := data.Get(StateCancelled).(bool) - if !cancelled { + cancelled, ok := data.GetOk(StateCancelled) + if !ok { + t.Fatal("could not get state cancelled") + } + if !cancelled.(bool) { t.Errorf("not cancelled") } } diff --git a/helper/multistep/multistep_test.go b/helper/multistep/multistep_test.go index 3ced60d83..af5cd3368 100644 --- a/helper/multistep/multistep_test.go +++ b/helper/multistep/multistep_test.go @@ -58,7 +58,7 @@ func (s TestStepSync) Run(context.Context, StateBag) StepAction { return ActionContinue } -func (s TestStepSync) Cleanup(StateBag) {} +func (s TestStepSync) Cleanup(StateBag) { close(s.Ch) } func (s TestStepWaitForever) Run(context.Context, StateBag) StepAction { select {} diff --git a/packer/build_test.go b/packer/build_test.go index c581f2a2f..6bc6e4339 100644 --- a/packer/build_test.go +++ b/packer/build_test.go @@ -371,12 +371,18 @@ func TestBuild_RunBeforePrepare(t *testing.T) { func TestBuild_Cancel(t *testing.T) { build := testBuild() - ctx, cancel := context.WithCancel(context.Background()) - cancel() - build.Run(ctx, nil) + build.Prepare() + + topCtx, topCtxCancel := context.WithCancel(context.Background()) builder := build.builder.(*MockBuilder) - if !builder.CancelCalled { - t.Fatal("cancel should be called") + + builder.RunFn = func(ctx context.Context) { + topCtxCancel() + } + + _, err := build.Run(topCtx, testUi()) + if err == nil { + t.Fatal("build should err") } } diff --git a/packer/builder_mock.go b/packer/builder_mock.go index 8ae5bdbbe..5ebd2d993 100644 --- a/packer/builder_mock.go +++ b/packer/builder_mock.go @@ -20,6 +20,7 @@ type MockBuilder struct { RunHook Hook RunUi Ui CancelCalled bool + RunFn func(ctx context.Context) } func (tb *MockBuilder) Prepare(config ...interface{}) ([]string, error) { @@ -40,6 +41,9 @@ func (tb *MockBuilder) Run(ctx context.Context, ui Ui, h Hook) (Artifact, error) if tb.RunNilResult { return nil, nil } + if tb.RunFn != nil { + tb.RunFn(ctx) + } if h != nil { if err := h.Run(ctx, HookProvision, ui, new(MockCommunicator), nil); err != nil { diff --git a/packer/hook_mock.go b/packer/hook_mock.go index ea51c7635..16571f1fe 100644 --- a/packer/hook_mock.go +++ b/packer/hook_mock.go @@ -2,31 +2,21 @@ package packer import ( "context" - "time" ) // MockHook is an implementation of Hook that can be used for tests. type MockHook struct { - RunFunc func() error + RunFunc func(context.Context) error - RunCalled bool - RunComm Communicator - RunData interface{} - RunName string - RunUi Ui - CancelCalled bool + RunCalled bool + RunComm Communicator + RunData interface{} + RunName string + RunUi Ui } func (t *MockHook) Run(ctx context.Context, name string, ui Ui, comm Communicator, data interface{}) error { - go func() { - select { - case <-time.After(2 * time.Minute): - case <-ctx.Done(): - t.CancelCalled = true - } - }() - t.RunCalled = true t.RunComm = comm t.RunData = data @@ -37,5 +27,5 @@ func (t *MockHook) Run(ctx context.Context, name string, ui Ui, comm Communicato return nil } - return t.RunFunc() + return t.RunFunc(ctx) } diff --git a/packer/hook_test.go b/packer/hook_test.go index 633220289..efe0f0fd0 100644 --- a/packer/hook_test.go +++ b/packer/hook_test.go @@ -2,53 +2,9 @@ package packer import ( "context" - "sync" "testing" - "time" ) -// A helper Hook implementation for testing cancels. -type CancelHook struct { - sync.Mutex - cancelCh chan struct{} - doneCh chan struct{} - - Cancelled bool -} - -func (h *CancelHook) Run(ctx context.Context, _ string, _ Ui, _ Communicator, _ interface{}) error { - go func() { - select { - case <-time.After(2 * time.Minute): - case <-ctx.Done(): - h.cancel() - } - }() - - h.Lock() - h.cancelCh = make(chan struct{}) - h.doneCh = make(chan struct{}) - h.Unlock() - - defer close(h.doneCh) - - select { - case <-h.cancelCh: - h.Cancelled = true - case <-time.After(1 * time.Second): - } - - return nil -} - -func (h *CancelHook) cancel() { - h.Lock() - close(h.cancelCh) - h.Unlock() - - <-h.doneCh -} - func TestDispatchHook_Implements(t *testing.T) { var _ Hook = new(DispatchHook) } @@ -78,20 +34,36 @@ func TestDispatchHook_Run(t *testing.T) { } } +// A helper Hook implementation for testing cancels. +// Run will wait indetinitelly until ctx is cancelled. +type CancelHook struct { + cancel func() +} + +func (h *CancelHook) Run(ctx context.Context, _ string, _ Ui, _ Communicator, _ interface{}) error { + h.cancel() + <-ctx.Done() + return ctx.Err() +} + func TestDispatchHook_cancel(t *testing.T) { - hook := new(CancelHook) + + cancelHook := new(CancelHook) dh := &DispatchHook{ Mapping: map[string][]Hook{ - "foo": {hook}, + "foo": {cancelHook}, }, } ctx, cancel := context.WithCancel(context.Background()) - go dh.Run(ctx, "foo", nil, nil, 42) - time.Sleep(100 * time.Millisecond) - cancel() + cancelHook.cancel = cancel - if !hook.Cancelled { - t.Fatal("hook should've cancelled") + errchan := make(chan error) + go func() { + errchan <- dh.Run(ctx, "foo", nil, nil, 42) + }() + + if err := <-errchan; err == nil { + t.Fatal("hook should've errored") } } diff --git a/packer/plugin/server.go b/packer/plugin/server.go index 22e74847e..0efefb7e5 100644 --- a/packer/plugin/server.go +++ b/packer/plugin/server.go @@ -101,7 +101,7 @@ func Server() (*packrpc.Server, error) { // Serve a single connection log.Println("Serving a plugin connection...") - return packrpc.NewServer(conn), nil + return packrpc.NewServer(conn) } func serverListener(minPort, maxPort int64) (net.Listener, error) { diff --git a/packer/provisioner_mock.go b/packer/provisioner_mock.go index 01bee4c3c..9555c5f57 100644 --- a/packer/provisioner_mock.go +++ b/packer/provisioner_mock.go @@ -2,20 +2,18 @@ package packer import ( "context" - "time" ) // MockProvisioner is an implementation of Provisioner that can be // used for tests. type MockProvisioner struct { - ProvFunc func() error + ProvFunc func(context.Context) error PrepCalled bool PrepConfigs []interface{} ProvCalled bool ProvCommunicator Communicator ProvUi Ui - CancelCalled bool } func (t *MockProvisioner) Prepare(configs ...interface{}) error { @@ -25,14 +23,6 @@ func (t *MockProvisioner) Prepare(configs ...interface{}) error { } func (t *MockProvisioner) Provision(ctx context.Context, ui Ui, comm Communicator) error { - go func() { - select { - case <-time.After(2 * time.Minute): - case <-ctx.Done(): - t.CancelCalled = true - } - }() - t.ProvCalled = true t.ProvCommunicator = comm t.ProvUi = ui @@ -41,7 +31,7 @@ func (t *MockProvisioner) Provision(ctx context.Context, ui Ui, comm Communicato return nil } - return t.ProvFunc() + return t.ProvFunc(ctx) } func (t *MockProvisioner) Communicator() Communicator { diff --git a/packer/provisioner_test.go b/packer/provisioner_test.go index 224d201d3..ac6ab591e 100644 --- a/packer/provisioner_test.go +++ b/packer/provisioner_test.go @@ -2,7 +2,6 @@ package packer import ( "context" - "sync" "testing" "time" ) @@ -63,18 +62,13 @@ func TestProvisionHook_nilComm(t *testing.T) { } func TestProvisionHook_cancel(t *testing.T) { - var lock sync.Mutex - order := make([]string, 0, 2) + topCtx, topCtxCancel := context.WithCancel(context.Background()) p := &MockProvisioner{ - ProvFunc: func() error { - time.Sleep(100 * time.Millisecond) - - lock.Lock() - defer lock.Unlock() - order = append(order, "prov") - - return nil + ProvFunc: func(ctx context.Context) error { + topCtxCancel() + <-ctx.Done() + return ctx.Err() }, } @@ -83,27 +77,10 @@ func TestProvisionHook_cancel(t *testing.T) { {p, nil, ""}, }, } - ctx, cancel := context.WithCancel(context.Background()) - finished := make(chan struct{}) - go func() { - hook.Run(ctx, "foo", nil, new(MockCommunicator), nil) - close(finished) - }() - - // Cancel it while it is running - time.Sleep(10 * time.Millisecond) - cancel() - lock.Lock() - order = append(order, "cancel") - lock.Unlock() - - // Wait - <-finished - - // Verify order - if len(order) != 2 || order[0] != "cancel" || order[1] != "prov" { - t.Fatalf("bad: %#v", order) + err := hook.Run(topCtx, "foo", nil, new(MockCommunicator), nil) + if err == nil { + t.Fatal("should have err") } } @@ -156,7 +133,7 @@ func TestPausedProvisionerProvision_waits(t *testing.T) { } dataCh := make(chan struct{}) - mock.ProvFunc = func() error { + mock.ProvFunc = func(context.Context) error { close(dataCh) return nil } @@ -177,28 +154,22 @@ func TestPausedProvisionerProvision_waits(t *testing.T) { } func TestPausedProvisionerCancel(t *testing.T) { + topCtx, cancelTopCtx := context.WithCancel(context.Background()) + mock := new(MockProvisioner) prov := &PausedProvisioner{ Provisioner: mock, } - provCh := make(chan struct{}) - mock.ProvFunc = func() error { - close(provCh) - time.Sleep(10 * time.Millisecond) - return nil + mock.ProvFunc = func(ctx context.Context) error { + cancelTopCtx() + <-ctx.Done() + return ctx.Err() } - ctx, cancel := context.WithCancel(context.Background()) - // Start provisioning and wait for it to start - go func() { - <-provCh - cancel() - }() - - prov.Provision(ctx, testUi(), new(MockCommunicator)) - if !mock.CancelCalled { - t.Fatal("cancel should be called") + err := prov.Provision(topCtx, testUi(), new(MockCommunicator)) + if err == nil { + t.Fatal("should have err") } } @@ -243,27 +214,21 @@ func TestDebuggedProvisionerProvision(t *testing.T) { } func TestDebuggedProvisionerCancel(t *testing.T) { + topCtx, topCtxCancel := context.WithCancel(context.Background()) + mock := new(MockProvisioner) prov := &DebuggedProvisioner{ Provisioner: mock, } - provCh := make(chan struct{}) - mock.ProvFunc = func() error { - close(provCh) - time.Sleep(10 * time.Millisecond) - return nil + mock.ProvFunc = func(ctx context.Context) error { + topCtxCancel() + <-ctx.Done() + return ctx.Err() } - ctx, cancel := context.WithCancel(context.Background()) - // Start provisioning and wait for it to start - go func() { - <-provCh - cancel() - }() - - prov.Provision(ctx, testUi(), new(MockCommunicator)) - if !mock.CancelCalled { - t.Fatal("cancel should be called") + err := prov.Provision(topCtx, testUi(), new(MockCommunicator)) + if err == nil { + t.Fatal("should have error") } } diff --git a/packer/rpc/build.go b/packer/rpc/build.go index 916ea7138..39b31ed67 100644 --- a/packer/rpc/build.go +++ b/packer/rpc/build.go @@ -2,6 +2,7 @@ package rpc import ( "context" + "log" "net/rpc" "github.com/hashicorp/packer/packer" @@ -17,6 +18,9 @@ type build struct { // BuildServer wraps a packer.Build implementation and makes it exportable // as part of a Golang RPC server. type BuildServer struct { + context context.Context + contextCancel func() + build packer.Build mux *muxBroker } @@ -50,6 +54,19 @@ func (b *build) Run(ctx context.Context, ui packer.Ui) ([]packer.Artifact, error server.RegisterUi(ui) go server.Serve() + done := make(chan interface{}) + defer close(done) + go func() { + select { + case <-ctx.Done(): + log.Printf("Cancelling build after context cancellation %v", ctx.Err()) + if err := b.client.Call("Build.Cancel", new(interface{}), new(interface{})); err != nil { + log.Printf("Error cancelling builder: %s", err) + } + case <-done: + } + }() + var result []uint32 if err := b.client.Call("Build.Run", nextId, &result); err != nil { return nil, err @@ -106,14 +123,18 @@ func (b *BuildServer) Prepare(args *interface{}, resp *BuildPrepareResponse) err return nil } -func (b *BuildServer) Run(ctx context.Context, streamId uint32, reply *[]uint32) error { +func (b *BuildServer) Run(streamId uint32, reply *[]uint32) error { + if b.context == nil { + b.context, b.contextCancel = context.WithCancel(context.Background()) + } + client, err := newClientWithMux(b.mux, streamId) if err != nil { return NewBasicError(err) } defer client.Close() - artifacts, err := b.build.Run(ctx, client.Ui()) + artifacts, err := b.build.Run(b.context, client.Ui()) if err != nil { return NewBasicError(err) } @@ -147,6 +168,8 @@ func (b *BuildServer) SetOnError(val *string, reply *interface{}) error { } func (b *BuildServer) Cancel(args *interface{}, reply *interface{}) error { - panic("cancel !") + if b.contextCancel != nil { + b.contextCancel() + } return nil } diff --git a/packer/rpc/build_test.go b/packer/rpc/build_test.go index 4084c04ec..aa19313de 100644 --- a/packer/rpc/build_test.go +++ b/packer/rpc/build_test.go @@ -15,6 +15,7 @@ type testBuild struct { nameCalled bool prepareCalled bool prepareWarnings []string + runFn func(context.Context) runCalled bool runUi packer.Ui setDebugCalled bool @@ -36,13 +37,13 @@ func (b *testBuild) Prepare() ([]string, error) { } func (b *testBuild) Run(ctx context.Context, ui packer.Ui) ([]packer.Artifact, error) { - go func() { - <-ctx.Done() - b.cancelCalled = true - }() b.runCalled = true b.runUi = ui + if b.runFn != nil { + b.runFn(ctx) + } + if b.errRunResult { return nil, errors.New("foo") } else { @@ -62,10 +63,6 @@ func (b *testBuild) SetOnError(string) { b.setOnErrorCalled = true } -func (b *testBuild) Cancel() { - b.cancelCalled = true -} - func TestBuild(t *testing.T) { b := new(testBuild) client, server := testClientServer(t) @@ -74,7 +71,7 @@ func TestBuild(t *testing.T) { server.RegisterBuild(b) bClient := client.Build() - ctx, cancel := context.WithCancel(context.Background()) + ctx := context.Background() // Test Name bClient.Name() @@ -131,12 +128,33 @@ func TestBuild(t *testing.T) { if !b.setOnErrorCalled { t.Fatal("should be called") } +} - // Test Cancel - cancel() - if !b.cancelCalled { - t.Fatal("should be called") +func TestBuild_cancel(t *testing.T) { + topCtx, cancelTopCtx := context.WithCancel(context.Background()) + + b := new(testBuild) + + done := make(chan interface{}) + b.runFn = func(ctx context.Context) { + cancelTopCtx() + <-ctx.Done() + close(done) } + + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + server.RegisterBuild(b) + bClient := client.Build() + + bClient.Prepare() + + ui := new(testUi) + bClient.Run(topCtx, ui) + + // if context cancellation is not propagated, this will timeout + <-done } func TestBuildPrepare_Warnings(t *testing.T) { diff --git a/packer/rpc/builder.go b/packer/rpc/builder.go index b8ba27b35..4b95bea3e 100644 --- a/packer/rpc/builder.go +++ b/packer/rpc/builder.go @@ -2,6 +2,7 @@ package rpc import ( "context" + "log" "net/rpc" "github.com/hashicorp/packer/packer" @@ -17,6 +18,9 @@ type builder struct { // BuilderServer wraps a packer.Builder implementation and makes it exportable // as part of a Golang RPC server. type BuilderServer struct { + context context.Context + contextCancel func() + builder packer.Builder mux *muxBroker } @@ -51,7 +55,21 @@ func (b *builder) Run(ctx context.Context, ui packer.Ui, hook packer.Hook) (pack server.RegisterUi(ui) go server.Serve() + done := make(chan interface{}) + defer close(done) + go func() { + select { + case <-ctx.Done(): + log.Printf("Cancelling builder after context cancellation %v", ctx.Err()) + if err := b.client.Call("Builder.Cancel", new(interface{}), new(interface{})); err != nil { + log.Printf("Error cancelling builder: %s", err) + } + case <-done: + } + }() + var responseId uint32 + if err := b.client.Call("Builder.Run", nextId, &responseId); err != nil { return nil, err } @@ -77,14 +95,18 @@ func (b *BuilderServer) Prepare(args *BuilderPrepareArgs, reply *BuilderPrepareR return nil } -func (b *BuilderServer) Run(ctx context.Context, streamId uint32, reply *uint32) error { +func (b *BuilderServer) Run(streamId uint32, reply *uint32) error { client, err := newClientWithMux(b.mux, streamId) if err != nil { return NewBasicError(err) } defer client.Close() - artifact, err := b.builder.Run(ctx, client.Ui(), client.Hook()) + if b.context == nil { + b.context, b.contextCancel = context.WithCancel(context.Background()) + } + + artifact, err := b.builder.Run(b.context, client.Ui(), client.Hook()) if err != nil { return NewBasicError(err) } @@ -92,11 +114,16 @@ func (b *BuilderServer) Run(ctx context.Context, streamId uint32, reply *uint32) *reply = 0 if artifact != nil { streamId = b.mux.NextId() - server := newServerWithMux(b.mux, streamId) - server.RegisterArtifact(artifact) - go server.Serve() + artifactServer := newServerWithMux(b.mux, streamId) + artifactServer.RegisterArtifact(artifact) + go artifactServer.Serve() *reply = streamId } return nil } + +func (b *BuilderServer) Cancel(args *interface{}, reply *interface{}) error { + b.contextCancel() + return nil +} diff --git a/packer/rpc/builder_test.go b/packer/rpc/builder_test.go index bbea2bb79..9d89239b0 100644 --- a/packer/rpc/builder_test.go +++ b/packer/rpc/builder_test.go @@ -127,19 +127,26 @@ func TestBuilderRun_ErrResult(t *testing.T) { } func TestBuilderCancel(t *testing.T) { + topCtx, topCtxCancel := context.WithCancel(context.Background()) + // var runCtx context.Context + b := new(packer.MockBuilder) + cancelled := false + b.RunFn = func(ctx context.Context) { + topCtxCancel() + <-ctx.Done() + cancelled = true + } client, server := testClientServer(t) defer client.Close() defer server.Close() server.RegisterBuilder(b) bClient := client.Builder() - ctx, cancel := context.WithCancel(context.Background()) - cancel() - bClient.Run(ctx, nil, nil) + bClient.Run(topCtx, new(testUi), new(packer.MockHook)) - if !b.CancelCalled { - t.Fatal("cancel should be called") + if !cancelled { + t.Fatal("context should have been cancelled") } } diff --git a/packer/rpc/client_test.go b/packer/rpc/client_test.go index 44c972e66..d087886f7 100644 --- a/packer/rpc/client_test.go +++ b/packer/rpc/client_test.go @@ -35,7 +35,7 @@ func testConn(t *testing.T) (net.Conn, net.Conn) { func testClientServer(t *testing.T) (*Client, *Server) { clientConn, serverConn := testConn(t) - server := NewServer(serverConn) + server, _ := NewServer(serverConn) go server.Serve() client, err := NewClient(clientConn) diff --git a/packer/rpc/hook.go b/packer/rpc/hook.go index 1ca368e45..e4e804839 100644 --- a/packer/rpc/hook.go +++ b/packer/rpc/hook.go @@ -18,6 +18,9 @@ type hook struct { // HookServer wraps a packer.Hook implementation and makes it exportable // as part of a Golang RPC server. type HookServer struct { + context context.Context + contextCancel func() + hook packer.Hook mux *muxBroker } @@ -35,6 +38,19 @@ func (h *hook) Run(ctx context.Context, name string, ui packer.Ui, comm packer.C server.RegisterUi(ui) go server.Serve() + done := make(chan interface{}) + defer close(done) + go func() { + select { + case <-ctx.Done(): + log.Printf("Cancelling hook after context cancellation %v", ctx.Err()) + if err := h.client.Call("Hook.Cancel", new(interface{}), new(interface{})); err != nil { + log.Printf("Error cancelling builder: %s", err) + } + case <-done: + } + }() + args := HookRunArgs{ Name: name, Data: data, @@ -44,24 +60,27 @@ func (h *hook) Run(ctx context.Context, name string, ui packer.Ui, comm packer.C return h.client.Call("Hook.Run", &args, new(interface{})) } -func (h *hook) Cancel() { - err := h.client.Call("Hook.Cancel", new(interface{}), new(interface{})) - if err != nil { - log.Printf("Hook.Cancel error: %s", err) - } -} - -func (h *HookServer) Run(ctx context.Context, args *HookRunArgs, reply *interface{}) error { +func (h *HookServer) Run(args *HookRunArgs, reply *interface{}) error { client, err := newClientWithMux(h.mux, args.StreamId) if err != nil { return NewBasicError(err) } defer client.Close() - if err := h.hook.Run(ctx, args.Name, client.Ui(), client.Communicator(), args.Data); err != nil { + if h.context == nil { + h.context, h.contextCancel = context.WithCancel(context.Background()) + } + if err := h.hook.Run(h.context, args.Name, client.Ui(), client.Communicator(), args.Data); err != nil { return NewBasicError(err) } *reply = nil return nil } + +func (h *HookServer) Cancel(args *interface{}, reply *interface{}) error { + if h.contextCancel != nil { + h.contextCancel() + } + return nil +} diff --git a/packer/rpc/hook_test.go b/packer/rpc/hook_test.go index f4a0b8169..700feda65 100644 --- a/packer/rpc/hook_test.go +++ b/packer/rpc/hook_test.go @@ -2,59 +2,25 @@ package rpc import ( "context" - "reflect" - "sync" "testing" - "time" "github.com/hashicorp/packer/packer" ) -func TestHookRPC(t *testing.T) { - // Create the UI to test - h := new(packer.MockHook) - ctx, cancel := context.WithCancel(context.Background()) - - // Serve - client, server := testClientServer(t) - defer client.Close() - defer server.Close() - server.RegisterHook(h) - hClient := client.Hook() - - // Test Run - ui := &testUi{} - hClient.Run(ctx, "foo", ui, nil, 42) - if !h.RunCalled { - t.Fatal("should be called") - } - - // Test Cancel - cancel() - if !h.CancelCalled { - t.Fatal("should be called") - } -} - func TestHook_Implements(t *testing.T) { var _ packer.Hook = new(hook) } func TestHook_cancelWhileRun(t *testing.T) { - var finishLock sync.Mutex - finishOrder := make([]string, 0, 2) + topCtx, cancelTopCtx := context.WithCancel(context.Background()) h := &packer.MockHook{ - RunFunc: func() error { - time.Sleep(100 * time.Millisecond) - - finishLock.Lock() - finishOrder = append(finishOrder, "run") - finishLock.Unlock() - return nil + RunFunc: func(ctx context.Context) error { + cancelTopCtx() + <-ctx.Done() + return ctx.Err() }, } - ctx, cancel := context.WithCancel(context.Background()) // Serve client, server := testClientServer(t) @@ -64,26 +30,9 @@ func TestHook_cancelWhileRun(t *testing.T) { hClient := client.Hook() // Start the run - finished := make(chan struct{}) - go func() { - hClient.Run(ctx, "foo", nil, nil, nil) - close(finished) - }() + err := hClient.Run(topCtx, "foo", nil, nil, nil) - // Cancel it pretty quickly. - time.Sleep(10 * time.Millisecond) - cancel() - - finishLock.Lock() - finishOrder = append(finishOrder, "cancel") - finishLock.Unlock() - - // Verify things are good - <-finished - - // Check the results - expected := []string{"cancel", "run"} - if !reflect.DeepEqual(finishOrder, expected) { - t.Fatalf("bad: %#v", finishOrder) + if err == nil { + t.Fatal("should have errored") } } diff --git a/packer/rpc/post_processor.go b/packer/rpc/post_processor.go index 1d068ae50..5b8c5aed4 100644 --- a/packer/rpc/post_processor.go +++ b/packer/rpc/post_processor.go @@ -2,6 +2,7 @@ package rpc import ( "context" + "log" "net/rpc" "github.com/hashicorp/packer/packer" @@ -17,6 +18,9 @@ type postProcessor struct { // PostProcessorServer wraps a packer.PostProcessor implementation and makes it // exportable as part of a Golang RPC server. type PostProcessorServer struct { + context context.Context + contextCancel func() + mux *muxBroker p packer.PostProcessor } @@ -47,6 +51,20 @@ func (p *postProcessor) PostProcess(ctx context.Context, ui packer.Ui, a packer. server.RegisterUi(ui) go server.Serve() + done := make(chan interface{}) + defer close(done) + + go func() { + select { + case <-ctx.Done(): + log.Printf("Cancelling post-processor after context cancellation %v", ctx.Err()) + if err := p.client.Call("PostProcessor.Cancel", new(interface{}), new(interface{})); err != nil { + log.Printf("Error cancelling post-processor: %s", err) + } + case <-done: + } + }() + var response PostProcessorProcessResponse if err := p.client.Call("PostProcessor.PostProcess", nextId, &response); err != nil { return nil, false, err @@ -73,15 +91,19 @@ func (p *PostProcessorServer) Configure(args *PostProcessorConfigureArgs, reply return err } -func (p *PostProcessorServer) PostProcess(ctx context.Context, streamId uint32, reply *PostProcessorProcessResponse) error { +func (p *PostProcessorServer) PostProcess(streamId uint32, reply *PostProcessorProcessResponse) error { client, err := newClientWithMux(p.mux, streamId) if err != nil { return NewBasicError(err) } defer client.Close() + if p.context == nil { + p.context, p.contextCancel = context.WithCancel(context.Background()) + } + streamId = 0 - artifactResult, keep, err := p.p.PostProcess(ctx, client.Ui(), client.Artifact()) + artifactResult, keep, err := p.p.PostProcess(p.context, client.Ui(), client.Artifact()) if err == nil && artifactResult != nil { streamId = p.mux.NextId() server := newServerWithMux(p.mux, streamId) @@ -97,3 +119,10 @@ func (p *PostProcessorServer) PostProcess(ctx context.Context, streamId uint32, return nil } + +func (b *PostProcessorServer) Cancel(args *interface{}, reply *interface{}) error { + if b.contextCancel != nil { + b.contextCancel() + } + return nil +} diff --git a/packer/rpc/post_processor_test.go b/packer/rpc/post_processor_test.go index 9858608c4..9d78ad02b 100644 --- a/packer/rpc/post_processor_test.go +++ b/packer/rpc/post_processor_test.go @@ -17,6 +17,8 @@ type TestPostProcessor struct { ppArtifact packer.Artifact ppArtifactId string ppUi packer.Ui + + postProcessFn func(context.Context) error } func (pp *TestPostProcessor) Configure(v ...interface{}) error { @@ -30,6 +32,9 @@ func (pp *TestPostProcessor) PostProcess(ctx context.Context, ui packer.Ui, a pa pp.ppArtifact = a pp.ppArtifactId = a.Id() pp.ppUi = ui + if pp.postProcessFn != nil { + return testPostProcessorArtifact, false, pp.postProcessFn(ctx) + } return testPostProcessorArtifact, false, nil } @@ -84,6 +89,41 @@ func TestPostProcessorRPC(t *testing.T) { } } +func TestPostProcessorRPC_cancel(t *testing.T) { + topCtx, cancelTopCtx := context.WithCancel(context.Background()) + + p := new(TestPostProcessor) + p.postProcessFn = func(ctx context.Context) error { + cancelTopCtx() + <-ctx.Done() + return ctx.Err() + } + + // Start the server + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + if err := server.RegisterPostProcessor(p); err != nil { + panic(err) + } + + ppClient := client.PostProcessor() + + // Test Configure + config := 42 + err := ppClient.Configure(config) + + // Test PostProcess + a := &packer.MockArtifact{ + IdValue: "ppTestId", + } + ui := new(testUi) + _, _, err = ppClient.PostProcess(topCtx, ui, a) + if err == nil { + t.Fatalf("should err") + } +} + func TestPostProcessor_Implements(t *testing.T) { var raw interface{} raw = new(postProcessor) diff --git a/packer/rpc/provisioner.go b/packer/rpc/provisioner.go index 1a66c6057..faad6f434 100644 --- a/packer/rpc/provisioner.go +++ b/packer/rpc/provisioner.go @@ -2,6 +2,7 @@ package rpc import ( "context" + "log" "net/rpc" "github.com/hashicorp/packer/packer" @@ -17,6 +18,9 @@ type provisioner struct { // ProvisionerServer wraps a packer.Provisioner implementation and makes it // exportable as part of a Golang RPC server. type ProvisionerServer struct { + context context.Context + contextCancel func() + p packer.Provisioner mux *muxBroker } @@ -41,23 +45,46 @@ func (p *provisioner) Provision(ctx context.Context, ui packer.Ui, comm packer.C server.RegisterUi(ui) go server.Serve() + done := make(chan interface{}) + defer close(done) + + go func() { + select { + case <-ctx.Done(): + log.Printf("Cancelling provisioner after context cancellation %v", ctx.Err()) + if err := p.client.Call("Provisioner.Cancel", new(interface{}), new(interface{})); err != nil { + log.Printf("Error cancelling provisioner: %s", err) + } + case <-done: + } + }() + return p.client.Call("Provisioner.Provision", nextId, new(interface{})) } -func (p *ProvisionerServer) Prepare(_ context.Context, args *ProvisionerPrepareArgs, reply *interface{}) error { +func (p *ProvisionerServer) Prepare(args *ProvisionerPrepareArgs, reply *interface{}) error { return p.p.Prepare(args.Configs...) } -func (p *ProvisionerServer) Provision(ctx context.Context, streamId uint32, reply *interface{}) error { +func (p *ProvisionerServer) Provision(streamId uint32, reply *interface{}) error { client, err := newClientWithMux(p.mux, streamId) if err != nil { return NewBasicError(err) } defer client.Close() - if err := p.p.Provision(ctx, client.Ui(), client.Communicator()); err != nil { + if p.context == nil { + p.context, p.contextCancel = context.WithCancel(context.Background()) + } + + if err := p.p.Provision(p.context, client.Ui(), client.Communicator()); err != nil { return NewBasicError(err) } return nil } + +func (p *ProvisionerServer) Cancel(args *interface{}, reply *interface{}) error { + p.contextCancel() + return nil +} diff --git a/packer/rpc/provisioner_test.go b/packer/rpc/provisioner_test.go index 1904bbd53..bbcbc88d8 100644 --- a/packer/rpc/provisioner_test.go +++ b/packer/rpc/provisioner_test.go @@ -9,8 +9,15 @@ import ( ) func TestProvisionerRPC(t *testing.T) { + topCtx, topCtxCancel := context.WithCancel(context.Background()) + // Create the interface to test p := new(packer.MockProvisioner) + p.ProvFunc = func(ctx context.Context) error { + topCtxCancel() + <-ctx.Done() + return ctx.Err() + } // Start the server client, server := testClientServer(t) @@ -18,7 +25,6 @@ func TestProvisionerRPC(t *testing.T) { defer server.Close() server.RegisterProvisioner(p) pClient := client.Provisioner() - ctx, cancel := context.WithCancel(context.Background()) // Test Prepare config := 42 pClient.Prepare(config) @@ -33,18 +39,13 @@ func TestProvisionerRPC(t *testing.T) { // Test Provision ui := &testUi{} comm := &packer.MockCommunicator{} - if err := pClient.Provision(ctx, ui, comm); err != nil { - t.Fatalf("err: %v", err) + if err := pClient.Provision(topCtx, ui, comm); err == nil { + t.Fatalf("Provison should have err") } if !p.ProvCalled { t.Fatal("should be called") } - // Test Cancel - cancel() - if !p.CancelCalled { - t.Fatal("cancel should be called") - } } func TestProvisioner_Implements(t *testing.T) { diff --git a/packer/rpc/server.go b/packer/rpc/server.go index c02f75409..e0b90d261 100644 --- a/packer/rpc/server.go +++ b/packer/rpc/server.go @@ -32,12 +32,15 @@ type Server struct { } // NewServer returns a new Packer RPC server. -func NewServer(conn io.ReadWriteCloser) *Server { - mux, _ := newMuxBrokerServer(conn) +func NewServer(conn io.ReadWriteCloser) (*Server, error) { + mux, err := newMuxBrokerServer(conn) + if err != nil { + return nil, err + } result := newServerWithMux(mux, 0) result.closeMux = true go mux.Run() - return result + return result, nil } func newServerWithMux(mux *muxBroker, streamId uint32) *Server { diff --git a/provisioner/windows-restart/provisioner_test.go b/provisioner/windows-restart/provisioner_test.go index b6039b1b8..035f5e246 100644 --- a/provisioner/windows-restart/provisioner_test.go +++ b/provisioner/windows-restart/provisioner_test.go @@ -342,40 +342,27 @@ func TestProvision_Cancel(t *testing.T) { ui := testUi() p := new(Provisioner) - var err error - comm := new(packer.MockCommunicator) p.Prepare(config) - waitStart := make(chan bool) - waitDone := make(chan bool) + done := make(chan error) + + topCtx, cancelTopCtx := context.WithCancel(context.Background()) // Block until cancel comes through waitForCommunicator = func(ctx context.Context, p *Provisioner) error { - waitStart <- true - panic("this test is incorrect") - for { - select { - case <-p.cancel: - } - } + cancelTopCtx() + <-ctx.Done() + return ctx.Err() } - ctx, cancel := context.WithCancel(context.Background()) // Create two go routines to provision and cancel in parallel // Provision will block until cancel happens go func() { - err = p.Provision(ctx, ui, comm) - waitDone <- true + done <- p.Provision(topCtx, ui, comm) }() - go func() { - <-waitStart - cancel() - }() - <-waitDone - // Expect interrupt error - if err == nil { + if err := <-done; err == nil { t.Fatal("should have error") } }