UIInput must return when the first SIGINT/SIGTERM signal is received (#4051)

Signed-off-by: Andrei Ciobanu <andrei.ciobanu@opentofu.org>
This commit is contained in:
Andrei Ciobanu 2026-04-27 17:03:57 +03:00 committed by GitHub
parent c6f06f8f6d
commit 0af2e8d521
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 109 additions and 64 deletions

View file

@ -170,7 +170,13 @@ type Local interface {
// backend's implementations of this to understand what this actually
// does, because this operation has no well-defined contract aside from
// "whatever it already does".
LocalRun(context.Context, *Operation) (*LocalRun, statemgr.Full, tfdiags.Diagnostics)
//
// Even though both contexts contain the tracing information, there is a crucial difference between the two:
// - The first one is non-cancellable meaning that it's safe to be used for thing like state unlocking and other
// critical operations that need to run even when the process is asked to end gracefully.
// - The second one is a cancellable context. This is the context that should be used to cancel operations
// for a graceful shutdown.
LocalRun(context.Context, context.Context, *Operation) (*LocalRun, statemgr.Full, tfdiags.Diagnostics)
}
// LocalRun represents the assortment of objects that we can collect or

View file

@ -88,7 +88,7 @@ func (b *Local) opApply(
op.Hooks = append(op.Hooks, stateHook)
// Get our context
lr, _, opState, contextDiags := b.localRun(ctx, op)
lr, _, opState, contextDiags := b.localRun(ctx, stopCtx, op)
diags = diags.Append(contextDiags)
if contextDiags.HasErrors() {
op.ReportResult(runningOp, diags)

View file

@ -28,7 +28,7 @@ import (
var _ backend.Local = (*Local)(nil)
// backend.Local implementation.
func (b *Local) LocalRun(ctx context.Context, op *backend.Operation) (*backend.LocalRun, statemgr.Full, tfdiags.Diagnostics) {
func (b *Local) LocalRun(ctx context.Context, stopCtx context.Context, op *backend.Operation) (*backend.LocalRun, statemgr.Full, tfdiags.Diagnostics) {
// Make sure the type is invalid. We use this as a way to know not
// to ask for input/validate. We're modifying this through a pointer,
// so we're mutating an object that belongs to the caller here, which
@ -39,11 +39,11 @@ func (b *Local) LocalRun(ctx context.Context, op *backend.Operation) (*backend.L
op.StateLocker = op.StateLocker.WithContext(context.Background())
lr, _, stateMgr, diags := b.localRun(ctx, op)
lr, _, stateMgr, diags := b.localRun(ctx, stopCtx, op)
return lr, stateMgr, diags
}
func (b *Local) localRun(ctx context.Context, op *backend.Operation) (*backend.LocalRun, *configload.Snapshot, statemgr.Full, tfdiags.Diagnostics) {
func (b *Local) localRun(ctx context.Context, stopCtx context.Context, op *backend.Operation) (*backend.LocalRun, *configload.Snapshot, statemgr.Full, tfdiags.Diagnostics) {
var diags tfdiags.Diagnostics
// Get the latest state.
@ -111,7 +111,7 @@ func (b *Local) localRun(ctx context.Context, op *backend.Operation) (*backend.L
op.ConfigLoader.ImportSourcesFromSnapshot(configSnap)
} else {
log.Printf("[TRACE] backend/local: populating backend.LocalRun for current working directory")
ret, configSnap, ctxDiags = b.localRunDirect(ctx, op, ret, &coreOpts, s)
ret, configSnap, ctxDiags = b.localRunDirect(ctx, stopCtx, op, ret, &coreOpts, s)
}
diags = diags.Append(ctxDiags)
if diags.HasErrors() {
@ -144,7 +144,7 @@ func (b *Local) localRun(ctx context.Context, op *backend.Operation) (*backend.L
return ret, configSnap, s, diags
}
func (b *Local) localRunDirect(ctx context.Context, op *backend.Operation, run *backend.LocalRun, coreOpts *tofu.ContextOpts, s statemgr.Full) (*backend.LocalRun, *configload.Snapshot, tfdiags.Diagnostics) {
func (b *Local) localRunDirect(ctx context.Context, stopCtx context.Context, op *backend.Operation, run *backend.LocalRun, coreOpts *tofu.ContextOpts, s statemgr.Full) (*backend.LocalRun, *configload.Snapshot, tfdiags.Diagnostics) {
var diags tfdiags.Diagnostics
// Load the configuration using the caller-provided configuration loader.
@ -191,9 +191,7 @@ func (b *Local) localRunDirect(ctx context.Context, op *backend.Operation, run *
} else {
// If interactive input is enabled, we might gather some more variable
// values through interactive prompts.
// TODO: Need to route the operation context through into here, so that
// the interactive prompts can be sensitive to its timeouts/etc.
rawVariables = b.interactiveCollectVariables(ctx, op.Variables, config.Module.Variables, op.UIIn)
rawVariables = b.interactiveCollectVariables(stopCtx, op.Variables, config.Module.Variables, op.UIIn)
}
variables, varDiags := backend.ParseVariableValues(rawVariables, config.Module.Variables)

View file

@ -10,8 +10,11 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/opentofu/opentofu/internal/providers"
"github.com/zclconf/go-cty/cty"
"github.com/opentofu/opentofu/internal/backend"
@ -49,7 +52,7 @@ func TestLocalRun(t *testing.T) {
StateLocker: stateLocker,
}
_, _, diags := b.LocalRun(context.Background(), op)
_, _, diags := b.LocalRun(context.Background(), t.Context(), op)
if diags.HasErrors() {
t.Fatalf("unexpected error: %s", diags.Err().Error())
}
@ -58,6 +61,41 @@ func TestLocalRun(t *testing.T) {
assertBackendStateLocked(t, b)
}
func TestLocalRun_ErrorWhenUiInputIsCancelled(t *testing.T) {
b := TestLocal(t)
p := TestLocalProvider(t, b, "test", applyFixtureSchema())
p.ApplyResourceChangeResponse = &providers.ApplyResourceChangeResponse{NewState: cty.ObjectVal(map[string]cty.Value{
"id": cty.StringVal("yes"),
"ami": cty.StringVal("bar"),
})}
op, done := testOperationApply(t, "./testdata/apply-with-vars")
run, err := b.Operation(context.Background(), op)
if err != nil {
t.Fatalf("bad: %s", err)
}
go func() {
<-time.After(1 * time.Second)
run.Stop()
}()
select {
case <-run.Done():
case <-time.After(5 * time.Second):
t.Fatalf("hit the timeout. expected for the operation to finish before the timeout")
}
if run.Result != backend.OperationFailure {
t.Fatal("operation suceeded but expected to fail")
}
expectedErrHeader := "Error: No value for required variable"
if errOutput := done(t).Stderr(); !strings.Contains(errOutput, expectedErrHeader) {
t.Fatalf("unexpected error output. Expected to contain %q but it does not:\n%s", expectedErrHeader, errOutput)
}
}
func TestLocalRun_error(t *testing.T) {
configDir := "./testdata/invalid"
b := TestLocal(t)
@ -80,7 +118,7 @@ func TestLocalRun_error(t *testing.T) {
StateLocker: stateLocker,
}
_, _, diags := b.LocalRun(context.Background(), op)
_, _, diags := b.LocalRun(context.Background(), t.Context(), op)
if !diags.HasErrors() {
t.Fatal("unexpected success")
}
@ -115,7 +153,7 @@ func TestLocalRun_cloudPlan(t *testing.T) {
StateLocker: stateLocker,
}
_, _, diags := b.LocalRun(context.Background(), op)
_, _, diags := b.LocalRun(context.Background(), t.Context(), op)
if !diags.HasErrors() {
t.Fatal("unexpected success")
}
@ -201,7 +239,7 @@ func TestLocalRun_stalePlan(t *testing.T) {
StateLocker: stateLocker,
}
_, _, diags := b.LocalRun(context.Background(), op)
_, _, diags := b.LocalRun(context.Background(), t.Context(), op)
if !diags.HasErrors() {
t.Fatal("unexpected success")
}

View file

@ -88,7 +88,7 @@ func (b *Local) opPlan(
}
// Get our context
lr, configSnap, opState, ctxDiags := b.localRun(ctx, op)
lr, configSnap, opState, ctxDiags := b.localRun(ctx, stopCtx, op)
diags = diags.Append(ctxDiags)
if ctxDiags.HasErrors() {
op.ReportResult(runningOp, diags)

View file

@ -59,7 +59,7 @@ func (b *Local) opRefresh(
op.PlanRefresh = true
// Get our context
lr, _, opState, contextDiags := b.localRun(ctx, op)
lr, _, opState, contextDiags := b.localRun(ctx, stopCtx, op)
diags = diags.Append(contextDiags)
if contextDiags.HasErrors() {
op.ReportResult(runningOp, diags)

View file

@ -27,7 +27,8 @@ import (
var _ backend.Local = (*Remote)(nil)
// LocalRun implements backend.Local.
func (b *Remote) LocalRun(ctx context.Context, op *backend.Operation) (*backend.LocalRun, statemgr.Full, tfdiags.Diagnostics) {
// Refer to the comments of backend.Local for more details about ctx vs stopCtx.
func (b *Remote) LocalRun(ctx context.Context, stopCtx context.Context, op *backend.Operation) (*backend.LocalRun, statemgr.Full, tfdiags.Diagnostics) {
var diags tfdiags.Diagnostics
ret := &backend.LocalRun{
PlanOpts: &tofu.PlanOpts{

View file

@ -215,7 +215,7 @@ func TestRemoteContextWithVars(t *testing.T) {
t.Fatal(err)
}
_, _, diags := b.LocalRun(t.Context(), op)
_, _, diags := b.LocalRun(t.Context(), t.Context(), op)
if test.WantError != "" {
if !diags.HasErrors() {
@ -439,7 +439,7 @@ func TestRemoteVariablesDoNotOverride(t *testing.T) {
}
}
lr, _, diags := b.LocalRun(t.Context(), op)
lr, _, diags := b.LocalRun(t.Context(), t.Context(), op)
if diags.HasErrors() {
t.Fatalf("unexpected error\ngot: %s\nwant: <no error>", diags.Err().Error())

View file

@ -27,7 +27,8 @@ import (
var _ backend.Local = (*Cloud)(nil)
// LocalRun implements backend.Local
func (b *Cloud) LocalRun(ctx context.Context, op *backend.Operation) (*backend.LocalRun, statemgr.Full, tfdiags.Diagnostics) {
// Refer to the comments of backend.Local for more details about ctx vs stopCtx.
func (b *Cloud) LocalRun(ctx context.Context, stopCtx context.Context, op *backend.Operation) (*backend.LocalRun, statemgr.Full, tfdiags.Diagnostics) {
var diags tfdiags.Diagnostics
ret := &backend.LocalRun{
PlanOpts: &tofu.PlanOpts{

View file

@ -214,7 +214,7 @@ func TestRemoteContextWithVars(t *testing.T) {
t.Fatal(err)
}
_, _, diags := b.LocalRun(t.Context(), op)
_, _, diags := b.LocalRun(t.Context(), t.Context(), op)
if test.WantError != "" {
if !diags.HasErrors() {
@ -438,7 +438,7 @@ func TestRemoteVariablesDoNotOverride(t *testing.T) {
}
}
lr, _, diags := b.LocalRun(t.Context(), op)
lr, _, diags := b.LocalRun(t.Context(), t.Context(), op)
if diags.HasErrors() {
t.Fatalf("unexpected error\ngot: %s\nwant: <no error>", diags.Err().Error())

View file

@ -142,7 +142,9 @@ func (c *ConsoleCommand) Run(rawArgs []string) int {
}
// Get the context
lr, _, ctxDiags := local.LocalRun(ctx, opReq)
stopCtx, cancel := c.InterruptibleContext(ctx)
defer cancel()
lr, _, ctxDiags := local.LocalRun(ctx, stopCtx, opReq)
diags = diags.Append(ctxDiags)
if ctxDiags.HasErrors() {
view.Diagnostics(diags)

View file

@ -171,7 +171,9 @@ func (c *GraphCommand) Run(rawArgs []string) int {
}
// Get the context
lr, _, ctxDiags := local.LocalRun(ctx, opReq)
stopCtx, cancel := c.InterruptibleContext(ctx)
defer cancel()
lr, _, ctxDiags := local.LocalRun(ctx, stopCtx, opReq)
diags = diags.Append(ctxDiags)
if ctxDiags.HasErrors() {
view.Diagnostics(diags)

View file

@ -238,7 +238,9 @@ func (c *ImportCommand) Run(rawArgs []string) int {
}
// Get the context
lr, state, ctxDiags := local.LocalRun(ctx, opReq)
stopCtx, cancel := c.InterruptibleContext(ctx)
defer cancel()
lr, state, ctxDiags := local.LocalRun(ctx, stopCtx, opReq)
diags = diags.Append(ctxDiags)
if ctxDiags.HasErrors() {
view.Diagnostics(diags)

View file

@ -116,7 +116,9 @@ func (c *ProvidersSchemaCommand) Run(rawArgs []string) int {
}
// Get the context
lr, _, ctxDiags := local.LocalRun(ctx, opReq)
stopCtx, cancel := c.InterruptibleContext(ctx)
defer cancel()
lr, _, ctxDiags := local.LocalRun(ctx, stopCtx, opReq)
diags = diags.Append(ctxDiags)
if ctxDiags.HasErrors() {
view.Diagnostics(diags)

View file

@ -123,7 +123,9 @@ func (c *StateShowCommand) Run(rawArgs []string) int {
}
// Get the context (required to get the schemas)
lr, _, ctxDiags := local.LocalRun(ctx, opReq)
stopCtx, cancel := c.InterruptibleContext(ctx)
defer cancel()
lr, _, ctxDiags := local.LocalRun(ctx, stopCtx, opReq)
if ctxDiags.HasErrors() {
view.Diagnostics(ctxDiags)
return 1

View file

@ -14,7 +14,6 @@ import (
"io"
"log"
"os"
"os/signal"
"strings"
"sync"
"sync/atomic"
@ -46,9 +45,8 @@ type UIInput struct {
result chan string
err chan string
interrupted bool
l sync.Mutex
once sync.Once
l sync.Mutex
once sync.Once
}
func (i *UIInput) Input(ctx context.Context, opts *tofu.InputOpts) (string, error) {
@ -74,11 +72,6 @@ func (i *UIInput) Input(ctx context.Context, opts *tofu.InputOpts) (string, erro
i.l.Lock()
defer i.l.Unlock()
// If we're interrupted, then don't ask for input
if i.interrupted {
return "", errors.New("interrupted")
}
// If we have test results, return those. testInputResponse is the
// "old" way of doing it and we should remove that.
if testInputResponse != nil {
@ -101,11 +94,6 @@ func (i *UIInput) Input(ctx context.Context, opts *tofu.InputOpts) (string, erro
log.Printf("[DEBUG] command: asking for input: %q", opts.Query)
// Listen for interrupts so we can cancel the input ask
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, os.Interrupt)
defer signal.Stop(sigCh)
// Build the output format for asking
var buf bytes.Buffer
buf.WriteString("[reset]")
@ -170,16 +158,7 @@ func (i *UIInput) Input(ctx context.Context, opts *tofu.InputOpts) (string, erro
// on a new line.
fmt.Fprintln(w)
return "", ctx.Err()
case <-sigCh:
// Print a newline so that any further output starts properly
// on a new line.
fmt.Fprintln(w)
// Mark that we were interrupted so future Ask calls fail.
i.interrupted = true
return "", errors.New("interrupted")
return "", fmt.Errorf("interrupted: %w", ctx.Err())
}
}

View file

@ -8,6 +8,7 @@ package command
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"sync/atomic"
@ -55,7 +56,7 @@ func TestUIInputInput_canceled(t *testing.T) {
// Get input until the context is canceled.
v, err := i.Input(ctx, &tofu.InputOpts{})
if err != context.Canceled {
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected a context.Canceled error, got: %v", err)
}
@ -70,20 +71,31 @@ func TestUIInputInput_canceled(t *testing.T) {
t.Fatalf("expected listening to be 1, got: %d", listening)
}
go func() {
// Fake input is given after 1 second.
time.Sleep(time.Second)
fmt.Fprint(w, "foo\n")
w.Close()
}()
v, err = i.Input(context.Background(), &tofu.InputOpts{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
// Using the same context that was cancelled should fail with the same error again when invoked again
{
_, err = i.Input(ctx, &tofu.InputOpts{})
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected a context.Canceled error, got: %v", err)
}
}
if v != "foo" {
t.Fatalf("unexpected input: %s", v)
{
// But asking for input with a new, uncancelled context, should work just fine
go func() {
// Fake input is given after 1 second.
time.Sleep(time.Second)
fmt.Fprint(w, "foo\n")
w.Close()
}()
v, err = i.Input(context.Background(), &tofu.InputOpts{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if v != "foo" {
t.Fatalf("unexpected input: %s", v)
}
}
}