diff --git a/internal/backend/local/backend_apply.go b/internal/backend/local/backend_apply.go index cde8fc93f8..8da9b5d993 100644 --- a/internal/backend/local/backend_apply.go +++ b/internal/backend/local/backend_apply.go @@ -34,15 +34,21 @@ const ( persistIntervalEnvironmentVariableName = "TF_STATE_PERSIST_INTERVAL" ) -func getEnvAsInt(envName string, defaultValue int) int { +func getEnvAsInt(envName string, defaultValue int) (int, tfdiags.Diagnostics) { + var diags tfdiags.Diagnostics if val, exists := os.LookupEnv(envName); exists { parsedVal, err := strconv.Atoi(val) - if err == nil { - return parsedVal + if err != nil { + diags = diags.Append(tfdiags.Sourceless( + tfdiags.Error, + "Invalid environment variable value", + fmt.Sprintf("The environment variable %s must be a valid integer, got %q.", envName, val), + )) + return 0, diags } - panic(fmt.Sprintf("Can't parse value '%s' of environment variable '%s'", val, envName)) + return parsedVal, diags } - return defaultValue + return defaultValue, diags } func (b *Local) opApply( @@ -112,10 +118,20 @@ func (b *Local) opApply( // stateHook uses schemas for when it periodically persists state to the // persistent storage backend. stateHook.Schemas = schemas - persistInterval := getEnvAsInt(persistIntervalEnvironmentVariableName, defaultPersistInterval) + persistInterval, intervalDiags := getEnvAsInt(persistIntervalEnvironmentVariableName, defaultPersistInterval) + diags = diags.Append(intervalDiags) + if intervalDiags.HasErrors() { + op.ReportResult(runningOp, diags) + return + } if persistInterval < defaultPersistInterval { - panic(fmt.Sprintf("Can't use value lower than %d for env variable %s, got %d", - defaultPersistInterval, persistIntervalEnvironmentVariableName, persistInterval)) + diags = diags.Append(tfdiags.Sourceless( + tfdiags.Error, + "Invalid environment variable value", + fmt.Sprintf("The environment variable %s must be at least %d, got %d.", persistIntervalEnvironmentVariableName, defaultPersistInterval, persistInterval), + )) + op.ReportResult(runningOp, diags) + return } stateHook.PersistInterval = time.Duration(persistInterval) * time.Second diff --git a/internal/backend/local/backend_apply_test.go b/internal/backend/local/backend_apply_test.go index ae973058e0..57df94e6e9 100644 --- a/internal/backend/local/backend_apply_test.go +++ b/internal/backend/local/backend_apply_test.go @@ -428,3 +428,108 @@ func TestApply_applyCanceledAutoApprove(t *testing.T) { } } + +func TestGetEnvAsInt(t *testing.T) { + const testEnv = "TEST_GET_ENV_AS_INT" + + tests := []struct { + name string + envValue string + defaultValue int + wantValue int + wantError bool + }{ + { + name: "env not set returns default", + envValue: "", + defaultValue: 20, + wantValue: 20, + wantError: false, + }, + { + name: "valid integer is parsed", + envValue: "30", + defaultValue: 20, + wantValue: 30, + wantError: false, + }, + { + name: "non-integer value returns error", + envValue: "abc", + defaultValue: 20, + wantError: true, + }, + { + name: "float value returns error", + envValue: "1.5", + defaultValue: 20, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.envValue != "" { + t.Setenv(testEnv, tt.envValue) + } + + got, diags := getEnvAsInt(testEnv, tt.defaultValue) + if tt.wantError { + if !diags.HasErrors() { + t.Errorf("expected error but got none, value=%d", got) + } + return + } + if diags.HasErrors() { + t.Fatalf("unexpected error: %s", diags.Err()) + } + if got != tt.wantValue { + t.Errorf("got %d, want %d", got, tt.wantValue) + } + }) + } +} + +func TestLocal_applyInvalidPersistInterval(t *testing.T) { + t.Run("non-integer value causes error diagnostic", func(t *testing.T) { + t.Setenv(persistIntervalEnvironmentVariableName, "abc") + + b := TestLocal(t) + TestLocalProvider(t, b, "test", applyFixtureSchema()) + + op, done := testOperationApply(t, "./testdata/apply") + + run, err := b.Operation(context.Background(), op) + if err != nil { + t.Fatalf("unexpected error starting operation: %v", err) + } + <-run.Done() + if run.Result == backend.OperationSuccess { + t.Fatalf("expected operation to fail with invalid %s=abc", persistIntervalEnvironmentVariableName) + } + if got, want := done(t).Stderr(), "Invalid environment variable value"; !strings.Contains(got, want) { + t.Errorf("expected stderr to contain %q, got:\n%s", want, got) + } + }) + + t.Run("below minimum value causes error diagnostic", func(t *testing.T) { + t.Setenv(persistIntervalEnvironmentVariableName, "5") + + b := TestLocal(t) + TestLocalProvider(t, b, "test", applyFixtureSchema()) + + op, done := testOperationApply(t, "./testdata/apply") + + run, err := b.Operation(context.Background(), op) + if err != nil { + t.Fatalf("unexpected error starting operation: %v", err) + } + <-run.Done() + if run.Result == backend.OperationSuccess { + t.Fatalf("expected operation to fail with invalid %s=5", persistIntervalEnvironmentVariableName) + } + if got, want := done(t).Stderr(), "Invalid environment variable value"; !strings.Contains(got, want) { + t.Errorf("expected stderr to contain %q, got:\n%s", want, got) + } + }) +}