diff --git a/internal/backend/backend.go b/internal/backend/backend.go index 7e75b32910..edba1a20f3 100644 --- a/internal/backend/backend.go +++ b/internal/backend/backend.go @@ -170,7 +170,13 @@ type Local interface { // backend's implementations of this to understand what this actually // does, because this operation has no well-defined contract aside from // "whatever it already does". - LocalRun(context.Context, *Operation) (*LocalRun, statemgr.Full, tfdiags.Diagnostics) + // + // Even though both contexts contain the tracing information, there is a crucial difference between the two: + // - The first one is non-cancellable meaning that it's safe to be used for thing like state unlocking and other + // critical operations that need to run even when the process is asked to end gracefully. + // - The second one is a cancellable context. This is the context that should be used to cancel operations + // for a graceful shutdown. + LocalRun(context.Context, context.Context, *Operation) (*LocalRun, statemgr.Full, tfdiags.Diagnostics) } // LocalRun represents the assortment of objects that we can collect or diff --git a/internal/backend/local/backend_apply.go b/internal/backend/local/backend_apply.go index fe8ecc2b17..b5cb9b8f53 100644 --- a/internal/backend/local/backend_apply.go +++ b/internal/backend/local/backend_apply.go @@ -88,7 +88,7 @@ func (b *Local) opApply( op.Hooks = append(op.Hooks, stateHook) // Get our context - lr, _, opState, contextDiags := b.localRun(ctx, op) + lr, _, opState, contextDiags := b.localRun(ctx, stopCtx, op) diags = diags.Append(contextDiags) if contextDiags.HasErrors() { op.ReportResult(runningOp, diags) diff --git a/internal/backend/local/backend_local.go b/internal/backend/local/backend_local.go index fc35d96955..d8929f369b 100644 --- a/internal/backend/local/backend_local.go +++ b/internal/backend/local/backend_local.go @@ -28,7 +28,7 @@ import ( var _ backend.Local = (*Local)(nil) // backend.Local implementation. -func (b *Local) LocalRun(ctx context.Context, op *backend.Operation) (*backend.LocalRun, statemgr.Full, tfdiags.Diagnostics) { +func (b *Local) LocalRun(ctx context.Context, stopCtx context.Context, op *backend.Operation) (*backend.LocalRun, statemgr.Full, tfdiags.Diagnostics) { // Make sure the type is invalid. We use this as a way to know not // to ask for input/validate. We're modifying this through a pointer, // so we're mutating an object that belongs to the caller here, which @@ -39,11 +39,11 @@ func (b *Local) LocalRun(ctx context.Context, op *backend.Operation) (*backend.L op.StateLocker = op.StateLocker.WithContext(context.Background()) - lr, _, stateMgr, diags := b.localRun(ctx, op) + lr, _, stateMgr, diags := b.localRun(ctx, stopCtx, op) return lr, stateMgr, diags } -func (b *Local) localRun(ctx context.Context, op *backend.Operation) (*backend.LocalRun, *configload.Snapshot, statemgr.Full, tfdiags.Diagnostics) { +func (b *Local) localRun(ctx context.Context, stopCtx context.Context, op *backend.Operation) (*backend.LocalRun, *configload.Snapshot, statemgr.Full, tfdiags.Diagnostics) { var diags tfdiags.Diagnostics // Get the latest state. @@ -111,7 +111,7 @@ func (b *Local) localRun(ctx context.Context, op *backend.Operation) (*backend.L op.ConfigLoader.ImportSourcesFromSnapshot(configSnap) } else { log.Printf("[TRACE] backend/local: populating backend.LocalRun for current working directory") - ret, configSnap, ctxDiags = b.localRunDirect(ctx, op, ret, &coreOpts, s) + ret, configSnap, ctxDiags = b.localRunDirect(ctx, stopCtx, op, ret, &coreOpts, s) } diags = diags.Append(ctxDiags) if diags.HasErrors() { @@ -144,7 +144,7 @@ func (b *Local) localRun(ctx context.Context, op *backend.Operation) (*backend.L return ret, configSnap, s, diags } -func (b *Local) localRunDirect(ctx context.Context, op *backend.Operation, run *backend.LocalRun, coreOpts *tofu.ContextOpts, s statemgr.Full) (*backend.LocalRun, *configload.Snapshot, tfdiags.Diagnostics) { +func (b *Local) localRunDirect(ctx context.Context, stopCtx context.Context, op *backend.Operation, run *backend.LocalRun, coreOpts *tofu.ContextOpts, s statemgr.Full) (*backend.LocalRun, *configload.Snapshot, tfdiags.Diagnostics) { var diags tfdiags.Diagnostics // Load the configuration using the caller-provided configuration loader. @@ -191,9 +191,7 @@ func (b *Local) localRunDirect(ctx context.Context, op *backend.Operation, run * } else { // If interactive input is enabled, we might gather some more variable // values through interactive prompts. - // TODO: Need to route the operation context through into here, so that - // the interactive prompts can be sensitive to its timeouts/etc. - rawVariables = b.interactiveCollectVariables(ctx, op.Variables, config.Module.Variables, op.UIIn) + rawVariables = b.interactiveCollectVariables(stopCtx, op.Variables, config.Module.Variables, op.UIIn) } variables, varDiags := backend.ParseVariableValues(rawVariables, config.Module.Variables) diff --git a/internal/backend/local/backend_local_test.go b/internal/backend/local/backend_local_test.go index 5b25e9f930..14cb782102 100644 --- a/internal/backend/local/backend_local_test.go +++ b/internal/backend/local/backend_local_test.go @@ -10,8 +10,11 @@ import ( "fmt" "os" "path/filepath" + "strings" "testing" + "time" + "github.com/opentofu/opentofu/internal/providers" "github.com/zclconf/go-cty/cty" "github.com/opentofu/opentofu/internal/backend" @@ -49,7 +52,7 @@ func TestLocalRun(t *testing.T) { StateLocker: stateLocker, } - _, _, diags := b.LocalRun(context.Background(), op) + _, _, diags := b.LocalRun(context.Background(), t.Context(), op) if diags.HasErrors() { t.Fatalf("unexpected error: %s", diags.Err().Error()) } @@ -58,6 +61,41 @@ func TestLocalRun(t *testing.T) { assertBackendStateLocked(t, b) } +func TestLocalRun_ErrorWhenUiInputIsCancelled(t *testing.T) { + b := TestLocal(t) + + p := TestLocalProvider(t, b, "test", applyFixtureSchema()) + p.ApplyResourceChangeResponse = &providers.ApplyResourceChangeResponse{NewState: cty.ObjectVal(map[string]cty.Value{ + "id": cty.StringVal("yes"), + "ami": cty.StringVal("bar"), + })} + + op, done := testOperationApply(t, "./testdata/apply-with-vars") + + run, err := b.Operation(context.Background(), op) + if err != nil { + t.Fatalf("bad: %s", err) + } + go func() { + <-time.After(1 * time.Second) + run.Stop() + }() + + select { + case <-run.Done(): + case <-time.After(5 * time.Second): + t.Fatalf("hit the timeout. expected for the operation to finish before the timeout") + } + if run.Result != backend.OperationFailure { + t.Fatal("operation suceeded but expected to fail") + } + + expectedErrHeader := "Error: No value for required variable" + if errOutput := done(t).Stderr(); !strings.Contains(errOutput, expectedErrHeader) { + t.Fatalf("unexpected error output. Expected to contain %q but it does not:\n%s", expectedErrHeader, errOutput) + } +} + func TestLocalRun_error(t *testing.T) { configDir := "./testdata/invalid" b := TestLocal(t) @@ -80,7 +118,7 @@ func TestLocalRun_error(t *testing.T) { StateLocker: stateLocker, } - _, _, diags := b.LocalRun(context.Background(), op) + _, _, diags := b.LocalRun(context.Background(), t.Context(), op) if !diags.HasErrors() { t.Fatal("unexpected success") } @@ -115,7 +153,7 @@ func TestLocalRun_cloudPlan(t *testing.T) { StateLocker: stateLocker, } - _, _, diags := b.LocalRun(context.Background(), op) + _, _, diags := b.LocalRun(context.Background(), t.Context(), op) if !diags.HasErrors() { t.Fatal("unexpected success") } @@ -201,7 +239,7 @@ func TestLocalRun_stalePlan(t *testing.T) { StateLocker: stateLocker, } - _, _, diags := b.LocalRun(context.Background(), op) + _, _, diags := b.LocalRun(context.Background(), t.Context(), op) if !diags.HasErrors() { t.Fatal("unexpected success") } diff --git a/internal/backend/local/backend_plan.go b/internal/backend/local/backend_plan.go index 13a45924d9..1ae8f9d48a 100644 --- a/internal/backend/local/backend_plan.go +++ b/internal/backend/local/backend_plan.go @@ -88,7 +88,7 @@ func (b *Local) opPlan( } // Get our context - lr, configSnap, opState, ctxDiags := b.localRun(ctx, op) + lr, configSnap, opState, ctxDiags := b.localRun(ctx, stopCtx, op) diags = diags.Append(ctxDiags) if ctxDiags.HasErrors() { op.ReportResult(runningOp, diags) diff --git a/internal/backend/local/backend_refresh.go b/internal/backend/local/backend_refresh.go index 786b5ef9c7..9fe102f023 100644 --- a/internal/backend/local/backend_refresh.go +++ b/internal/backend/local/backend_refresh.go @@ -59,7 +59,7 @@ func (b *Local) opRefresh( op.PlanRefresh = true // Get our context - lr, _, opState, contextDiags := b.localRun(ctx, op) + lr, _, opState, contextDiags := b.localRun(ctx, stopCtx, op) diags = diags.Append(contextDiags) if contextDiags.HasErrors() { op.ReportResult(runningOp, diags) diff --git a/internal/backend/remote/backend_context.go b/internal/backend/remote/backend_context.go index 9ed8afbbd5..65f592a910 100644 --- a/internal/backend/remote/backend_context.go +++ b/internal/backend/remote/backend_context.go @@ -27,7 +27,8 @@ import ( var _ backend.Local = (*Remote)(nil) // LocalRun implements backend.Local. -func (b *Remote) LocalRun(ctx context.Context, op *backend.Operation) (*backend.LocalRun, statemgr.Full, tfdiags.Diagnostics) { +// Refer to the comments of backend.Local for more details about ctx vs stopCtx. +func (b *Remote) LocalRun(ctx context.Context, stopCtx context.Context, op *backend.Operation) (*backend.LocalRun, statemgr.Full, tfdiags.Diagnostics) { var diags tfdiags.Diagnostics ret := &backend.LocalRun{ PlanOpts: &tofu.PlanOpts{ diff --git a/internal/backend/remote/backend_context_test.go b/internal/backend/remote/backend_context_test.go index 0bc8bc445c..7faa52e31a 100644 --- a/internal/backend/remote/backend_context_test.go +++ b/internal/backend/remote/backend_context_test.go @@ -215,7 +215,7 @@ func TestRemoteContextWithVars(t *testing.T) { t.Fatal(err) } - _, _, diags := b.LocalRun(t.Context(), op) + _, _, diags := b.LocalRun(t.Context(), t.Context(), op) if test.WantError != "" { if !diags.HasErrors() { @@ -439,7 +439,7 @@ func TestRemoteVariablesDoNotOverride(t *testing.T) { } } - lr, _, diags := b.LocalRun(t.Context(), op) + lr, _, diags := b.LocalRun(t.Context(), t.Context(), op) if diags.HasErrors() { t.Fatalf("unexpected error\ngot: %s\nwant: ", diags.Err().Error()) diff --git a/internal/cloud/backend_context.go b/internal/cloud/backend_context.go index 82c365afbf..2a2b9ee53d 100644 --- a/internal/cloud/backend_context.go +++ b/internal/cloud/backend_context.go @@ -27,7 +27,8 @@ import ( var _ backend.Local = (*Cloud)(nil) // LocalRun implements backend.Local -func (b *Cloud) LocalRun(ctx context.Context, op *backend.Operation) (*backend.LocalRun, statemgr.Full, tfdiags.Diagnostics) { +// Refer to the comments of backend.Local for more details about ctx vs stopCtx. +func (b *Cloud) LocalRun(ctx context.Context, stopCtx context.Context, op *backend.Operation) (*backend.LocalRun, statemgr.Full, tfdiags.Diagnostics) { var diags tfdiags.Diagnostics ret := &backend.LocalRun{ PlanOpts: &tofu.PlanOpts{ diff --git a/internal/cloud/backend_context_test.go b/internal/cloud/backend_context_test.go index dea434f414..6c9e0ad3e8 100644 --- a/internal/cloud/backend_context_test.go +++ b/internal/cloud/backend_context_test.go @@ -214,7 +214,7 @@ func TestRemoteContextWithVars(t *testing.T) { t.Fatal(err) } - _, _, diags := b.LocalRun(t.Context(), op) + _, _, diags := b.LocalRun(t.Context(), t.Context(), op) if test.WantError != "" { if !diags.HasErrors() { @@ -438,7 +438,7 @@ func TestRemoteVariablesDoNotOverride(t *testing.T) { } } - lr, _, diags := b.LocalRun(t.Context(), op) + lr, _, diags := b.LocalRun(t.Context(), t.Context(), op) if diags.HasErrors() { t.Fatalf("unexpected error\ngot: %s\nwant: ", diags.Err().Error()) diff --git a/internal/command/console.go b/internal/command/console.go index 81a1a97c82..2fc4c66f84 100644 --- a/internal/command/console.go +++ b/internal/command/console.go @@ -142,7 +142,9 @@ func (c *ConsoleCommand) Run(rawArgs []string) int { } // Get the context - lr, _, ctxDiags := local.LocalRun(ctx, opReq) + stopCtx, cancel := c.InterruptibleContext(ctx) + defer cancel() + lr, _, ctxDiags := local.LocalRun(ctx, stopCtx, opReq) diags = diags.Append(ctxDiags) if ctxDiags.HasErrors() { view.Diagnostics(diags) diff --git a/internal/command/graph.go b/internal/command/graph.go index 3b73200a4f..91a7e0c200 100644 --- a/internal/command/graph.go +++ b/internal/command/graph.go @@ -171,7 +171,9 @@ func (c *GraphCommand) Run(rawArgs []string) int { } // Get the context - lr, _, ctxDiags := local.LocalRun(ctx, opReq) + stopCtx, cancel := c.InterruptibleContext(ctx) + defer cancel() + lr, _, ctxDiags := local.LocalRun(ctx, stopCtx, opReq) diags = diags.Append(ctxDiags) if ctxDiags.HasErrors() { view.Diagnostics(diags) diff --git a/internal/command/import.go b/internal/command/import.go index e781710960..515b115851 100644 --- a/internal/command/import.go +++ b/internal/command/import.go @@ -238,7 +238,9 @@ func (c *ImportCommand) Run(rawArgs []string) int { } // Get the context - lr, state, ctxDiags := local.LocalRun(ctx, opReq) + stopCtx, cancel := c.InterruptibleContext(ctx) + defer cancel() + lr, state, ctxDiags := local.LocalRun(ctx, stopCtx, opReq) diags = diags.Append(ctxDiags) if ctxDiags.HasErrors() { view.Diagnostics(diags) diff --git a/internal/command/providers_schema.go b/internal/command/providers_schema.go index 146c30c41a..f252fe7c9a 100644 --- a/internal/command/providers_schema.go +++ b/internal/command/providers_schema.go @@ -116,7 +116,9 @@ func (c *ProvidersSchemaCommand) Run(rawArgs []string) int { } // Get the context - lr, _, ctxDiags := local.LocalRun(ctx, opReq) + stopCtx, cancel := c.InterruptibleContext(ctx) + defer cancel() + lr, _, ctxDiags := local.LocalRun(ctx, stopCtx, opReq) diags = diags.Append(ctxDiags) if ctxDiags.HasErrors() { view.Diagnostics(diags) diff --git a/internal/command/state_show.go b/internal/command/state_show.go index f7253e26e5..af8f391916 100644 --- a/internal/command/state_show.go +++ b/internal/command/state_show.go @@ -123,7 +123,9 @@ func (c *StateShowCommand) Run(rawArgs []string) int { } // Get the context (required to get the schemas) - lr, _, ctxDiags := local.LocalRun(ctx, opReq) + stopCtx, cancel := c.InterruptibleContext(ctx) + defer cancel() + lr, _, ctxDiags := local.LocalRun(ctx, stopCtx, opReq) if ctxDiags.HasErrors() { view.Diagnostics(ctxDiags) return 1 diff --git a/internal/command/ui_input.go b/internal/command/ui_input.go index 417817895b..dcf3dbbd6d 100644 --- a/internal/command/ui_input.go +++ b/internal/command/ui_input.go @@ -14,7 +14,6 @@ import ( "io" "log" "os" - "os/signal" "strings" "sync" "sync/atomic" @@ -46,9 +45,8 @@ type UIInput struct { result chan string err chan string - interrupted bool - l sync.Mutex - once sync.Once + l sync.Mutex + once sync.Once } func (i *UIInput) Input(ctx context.Context, opts *tofu.InputOpts) (string, error) { @@ -74,11 +72,6 @@ func (i *UIInput) Input(ctx context.Context, opts *tofu.InputOpts) (string, erro i.l.Lock() defer i.l.Unlock() - // If we're interrupted, then don't ask for input - if i.interrupted { - return "", errors.New("interrupted") - } - // If we have test results, return those. testInputResponse is the // "old" way of doing it and we should remove that. if testInputResponse != nil { @@ -101,11 +94,6 @@ func (i *UIInput) Input(ctx context.Context, opts *tofu.InputOpts) (string, erro log.Printf("[DEBUG] command: asking for input: %q", opts.Query) - // Listen for interrupts so we can cancel the input ask - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, os.Interrupt) - defer signal.Stop(sigCh) - // Build the output format for asking var buf bytes.Buffer buf.WriteString("[reset]") @@ -170,16 +158,7 @@ func (i *UIInput) Input(ctx context.Context, opts *tofu.InputOpts) (string, erro // on a new line. fmt.Fprintln(w) - return "", ctx.Err() - case <-sigCh: - // Print a newline so that any further output starts properly - // on a new line. - fmt.Fprintln(w) - - // Mark that we were interrupted so future Ask calls fail. - i.interrupted = true - - return "", errors.New("interrupted") + return "", fmt.Errorf("interrupted: %w", ctx.Err()) } } diff --git a/internal/command/ui_input_test.go b/internal/command/ui_input_test.go index 0823f043df..02ffd60dd9 100644 --- a/internal/command/ui_input_test.go +++ b/internal/command/ui_input_test.go @@ -8,6 +8,7 @@ package command import ( "bytes" "context" + "errors" "fmt" "io" "sync/atomic" @@ -55,7 +56,7 @@ func TestUIInputInput_canceled(t *testing.T) { // Get input until the context is canceled. v, err := i.Input(ctx, &tofu.InputOpts{}) - if err != context.Canceled { + if !errors.Is(err, context.Canceled) { t.Fatalf("expected a context.Canceled error, got: %v", err) } @@ -70,20 +71,31 @@ func TestUIInputInput_canceled(t *testing.T) { t.Fatalf("expected listening to be 1, got: %d", listening) } - go func() { - // Fake input is given after 1 second. - time.Sleep(time.Second) - fmt.Fprint(w, "foo\n") - w.Close() - }() - - v, err = i.Input(context.Background(), &tofu.InputOpts{}) - if err != nil { - t.Fatalf("unexpected error: %v", err) + // Using the same context that was cancelled should fail with the same error again when invoked again + { + _, err = i.Input(ctx, &tofu.InputOpts{}) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected a context.Canceled error, got: %v", err) + } } - if v != "foo" { - t.Fatalf("unexpected input: %s", v) + { + // But asking for input with a new, uncancelled context, should work just fine + go func() { + // Fake input is given after 1 second. + time.Sleep(time.Second) + fmt.Fprint(w, "foo\n") + w.Close() + }() + + v, err = i.Input(context.Background(), &tofu.InputOpts{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if v != "foo" { + t.Fatalf("unexpected input: %s", v) + } } }