diff --git a/internal/tofu/context_test.go b/internal/tofu/context_test.go index 45313bdd76..a6b2cac5bd 100644 --- a/internal/tofu/context_test.go +++ b/internal/tofu/context_test.go @@ -21,6 +21,7 @@ import ( "github.com/hashicorp/go-version" "github.com/zclconf/go-cty/cty" + "github.com/opentofu/opentofu/internal/addrs" "github.com/opentofu/opentofu/internal/configs" "github.com/opentofu/opentofu/internal/configs/configload" "github.com/opentofu/opentofu/internal/configs/configschema" @@ -33,6 +34,7 @@ import ( "github.com/opentofu/opentofu/internal/states" "github.com/opentofu/opentofu/internal/states/statefile" "github.com/opentofu/opentofu/internal/tfdiags" + "github.com/opentofu/opentofu/internal/tracing" tfversion "github.com/opentofu/opentofu/version" ) @@ -255,6 +257,71 @@ resource "implicit_thing" "b" { } } +func TestContext_contextValuesPropagation(t *testing.T) { + // This test verifies that our code is propagating context.Context values + // through the system at least well enough that they can reach provider + // calls. It does so using [tracing.ContextProbe], which is a helper for + // probing to make sure that values (in this case, the probe itself) + // are able to reach calls to [tracing.ContextProbeReport] that are included + // in the [MockProvider] methods. + + ctx, probe := tracing.NewContextProbe(t, t.Context()) + tofuCtx := testContext2(t, &ContextOpts{ + Providers: map[addrs.Provider]providers.Factory{ + addrs.NewBuiltInProvider("test"): providers.FactoryFixed(&MockProvider{ + GetProviderSchemaResponse: &providers.GetProviderSchemaResponse{ + Provider: providers.Schema{ + Block: &configschema.Block{}, + }, + ResourceTypes: map[string]providers.Schema{ + "test": { + Block: &configschema.Block{}, + }, + }, + DataSources: map[string]providers.Schema{ + "test": { + Block: &configschema.Block{}, + }, + }, + }, + ReadDataSourceResponse: &providers.ReadDataSourceResponse{ + State: cty.EmptyObjectVal, + }, + }), + }, + }) + m := testModuleInline(t, map[string]string{ + "main.tf": ` + terraform { + required_providers { + test = { + source = "terraform.io/builtin/test" + } + } + } + resource "test" "test" {} + data "test" "test" {} + `, + }) + + plan, diags := tofuCtx.Plan(ctx, m, states.NewState(), DefaultPlanOpts) + assertNoErrors(t, diags) + _, diags = tofuCtx.Apply(ctx, plan, m) + assertNoErrors(t, diags) + + probe.ExpectReportsFrom(t, + "github.com/opentofu/opentofu/internal/tofu.(*MockProvider).GetProviderSchema", + "github.com/opentofu/opentofu/internal/tofu.(*MockProvider).ValidateProviderConfig", + "github.com/opentofu/opentofu/internal/tofu.(*MockProvider).ValidateDataResourceConfig", + "github.com/opentofu/opentofu/internal/tofu.(*MockProvider).ValidateResourceConfig", + //"github.com/opentofu/opentofu/internal/tofu.(*MockProvider).Configure", // FIXME: Not working yet + "github.com/opentofu/opentofu/internal/tofu.(*MockProvider).ReadDataSource", + "github.com/opentofu/opentofu/internal/tofu.(*MockProvider).PlanResourceChange", + "github.com/opentofu/opentofu/internal/tofu.(*MockProvider).ApplyResourceChange", + "github.com/opentofu/opentofu/internal/tofu.(*MockProvider).Close", + ) +} + func testContext2(t testing.TB, opts *ContextOpts) *Context { t.Helper() diff --git a/internal/tofu/provider_mock.go b/internal/tofu/provider_mock.go index 99e582cb29..da9765016d 100644 --- a/internal/tofu/provider_mock.go +++ b/internal/tofu/provider_mock.go @@ -16,6 +16,7 @@ import ( "github.com/opentofu/opentofu/internal/configs/hcl2shim" "github.com/opentofu/opentofu/internal/providers" + "github.com/opentofu/opentofu/internal/tracing" ) var _ providers.Interface = (*MockProvider)(nil) @@ -107,7 +108,8 @@ type MockProvider struct { CloseError error } -func (p *MockProvider) GetProviderSchema(_ context.Context) providers.GetProviderSchemaResponse { +func (p *MockProvider) GetProviderSchema(ctx context.Context) providers.GetProviderSchemaResponse { + tracing.ContextProbeReport(ctx, 0) p.Lock() defer p.Unlock() p.GetProviderSchemaCalled = true @@ -129,7 +131,8 @@ func (p *MockProvider) getProviderSchema() providers.GetProviderSchemaResponse { } } -func (p *MockProvider) ValidateProviderConfig(_ context.Context, r providers.ValidateProviderConfigRequest) (resp providers.ValidateProviderConfigResponse) { +func (p *MockProvider) ValidateProviderConfig(ctx context.Context, r providers.ValidateProviderConfigRequest) (resp providers.ValidateProviderConfigResponse) { + tracing.ContextProbeReport(ctx, 0) p.Lock() defer p.Unlock() @@ -147,7 +150,8 @@ func (p *MockProvider) ValidateProviderConfig(_ context.Context, r providers.Val return resp } -func (p *MockProvider) ValidateResourceConfig(_ context.Context, r providers.ValidateResourceConfigRequest) (resp providers.ValidateResourceConfigResponse) { +func (p *MockProvider) ValidateResourceConfig(ctx context.Context, r providers.ValidateResourceConfigRequest) (resp providers.ValidateResourceConfigResponse) { + tracing.ContextProbeReport(ctx, 0) p.Lock() defer p.Unlock() @@ -179,7 +183,8 @@ func (p *MockProvider) ValidateResourceConfig(_ context.Context, r providers.Val return resp } -func (p *MockProvider) ValidateDataResourceConfig(_ context.Context, r providers.ValidateDataResourceConfigRequest) (resp providers.ValidateDataResourceConfigResponse) { +func (p *MockProvider) ValidateDataResourceConfig(ctx context.Context, r providers.ValidateDataResourceConfigRequest) (resp providers.ValidateDataResourceConfigResponse) { + tracing.ContextProbeReport(ctx, 0) p.Lock() defer p.Unlock() @@ -209,7 +214,8 @@ func (p *MockProvider) ValidateDataResourceConfig(_ context.Context, r providers return resp } -func (p *MockProvider) UpgradeResourceState(_ context.Context, r providers.UpgradeResourceStateRequest) (resp providers.UpgradeResourceStateResponse) { +func (p *MockProvider) UpgradeResourceState(ctx context.Context, r providers.UpgradeResourceStateRequest) (resp providers.UpgradeResourceStateResponse) { + tracing.ContextProbeReport(ctx, 0) p.Lock() defer p.Unlock() @@ -258,7 +264,8 @@ func (p *MockProvider) UpgradeResourceState(_ context.Context, r providers.Upgra return resp } -func (p *MockProvider) MoveResourceState(_ context.Context, r providers.MoveResourceStateRequest) providers.MoveResourceStateResponse { +func (p *MockProvider) MoveResourceState(ctx context.Context, r providers.MoveResourceStateRequest) providers.MoveResourceStateResponse { + tracing.ContextProbeReport(ctx, 0) var resp providers.MoveResourceStateResponse p.Lock() defer p.Unlock() @@ -311,7 +318,8 @@ func (p *MockProvider) MoveResourceState(_ context.Context, r providers.MoveReso return resp } -func (p *MockProvider) ConfigureProvider(_ context.Context, r providers.ConfigureProviderRequest) (resp providers.ConfigureProviderResponse) { +func (p *MockProvider) ConfigureProvider(ctx context.Context, r providers.ConfigureProviderRequest) (resp providers.ConfigureProviderResponse) { + tracing.ContextProbeReport(ctx, 0) p.Lock() defer p.Unlock() @@ -329,7 +337,9 @@ func (p *MockProvider) ConfigureProvider(_ context.Context, r providers.Configur return resp } -func (p *MockProvider) Stop(_ context.Context) error { +func (p *MockProvider) Stop(ctx context.Context) error { + tracing.ContextProbeReport(ctx, 0) + // We intentionally don't lock in this one because the whole point of this // method is to be called concurrently with another operation that can // be cancelled. The provider itself is responsible for handling @@ -343,7 +353,8 @@ func (p *MockProvider) Stop(_ context.Context) error { return p.StopResponse } -func (p *MockProvider) ReadResource(_ context.Context, r providers.ReadResourceRequest) (resp providers.ReadResourceResponse) { +func (p *MockProvider) ReadResource(ctx context.Context, r providers.ReadResourceRequest) (resp providers.ReadResourceResponse) { + tracing.ContextProbeReport(ctx, 0) p.Lock() defer p.Unlock() @@ -384,7 +395,8 @@ func (p *MockProvider) ReadResource(_ context.Context, r providers.ReadResourceR return resp } -func (p *MockProvider) PlanResourceChange(_ context.Context, r providers.PlanResourceChangeRequest) (resp providers.PlanResourceChangeResponse) { +func (p *MockProvider) PlanResourceChange(ctx context.Context, r providers.PlanResourceChangeRequest) (resp providers.PlanResourceChangeResponse) { + tracing.ContextProbeReport(ctx, 0) p.Lock() defer p.Unlock() @@ -464,7 +476,8 @@ func (p *MockProvider) PlanResourceChange(_ context.Context, r providers.PlanRes return resp } -func (p *MockProvider) ApplyResourceChange(_ context.Context, r providers.ApplyResourceChangeRequest) (resp providers.ApplyResourceChangeResponse) { +func (p *MockProvider) ApplyResourceChange(ctx context.Context, r providers.ApplyResourceChangeRequest) (resp providers.ApplyResourceChangeResponse) { + tracing.ContextProbeReport(ctx, 0) p.Lock() defer p.Unlock() p.ApplyResourceChangeCalled = true @@ -519,7 +532,8 @@ func (p *MockProvider) ApplyResourceChange(_ context.Context, r providers.ApplyR return resp } -func (p *MockProvider) ImportResourceState(_ context.Context, r providers.ImportResourceStateRequest) (resp providers.ImportResourceStateResponse) { +func (p *MockProvider) ImportResourceState(ctx context.Context, r providers.ImportResourceStateRequest) (resp providers.ImportResourceStateResponse) { + tracing.ContextProbeReport(ctx, 0) p.Lock() defer p.Unlock() @@ -561,7 +575,8 @@ func (p *MockProvider) ImportResourceState(_ context.Context, r providers.Import return resp } -func (p *MockProvider) ReadDataSource(_ context.Context, r providers.ReadDataSourceRequest) (resp providers.ReadDataSourceResponse) { +func (p *MockProvider) ReadDataSource(ctx context.Context, r providers.ReadDataSourceRequest) (resp providers.ReadDataSourceResponse) { + tracing.ContextProbeReport(ctx, 0) p.Lock() defer p.Unlock() @@ -584,7 +599,8 @@ func (p *MockProvider) ReadDataSource(_ context.Context, r providers.ReadDataSou return resp } -func (p *MockProvider) GetFunctions(_ context.Context) (resp providers.GetFunctionsResponse) { +func (p *MockProvider) GetFunctions(ctx context.Context) (resp providers.GetFunctionsResponse) { + tracing.ContextProbeReport(ctx, 0) p.Lock() defer p.Unlock() @@ -600,7 +616,8 @@ func (p *MockProvider) GetFunctions(_ context.Context) (resp providers.GetFuncti return resp } -func (p *MockProvider) CallFunction(_ context.Context, r providers.CallFunctionRequest) (resp providers.CallFunctionResponse) { +func (p *MockProvider) CallFunction(ctx context.Context, r providers.CallFunctionRequest) (resp providers.CallFunctionResponse) { + tracing.ContextProbeReport(ctx, 0) p.Lock() defer p.Unlock() @@ -617,7 +634,8 @@ func (p *MockProvider) CallFunction(_ context.Context, r providers.CallFunctionR return resp } -func (p *MockProvider) Close(_ context.Context) error { +func (p *MockProvider) Close(ctx context.Context) error { + tracing.ContextProbeReport(ctx, 0) p.Lock() defer p.Unlock() diff --git a/internal/tracing/context_probe.go b/internal/tracing/context_probe.go new file mode 100644 index 0000000000..3463dfaa42 --- /dev/null +++ b/internal/tracing/context_probe.go @@ -0,0 +1,120 @@ +// Copyright (c) The OpenTofu Authors +// SPDX-License-Identifier: MPL-2.0 +// Copyright (c) 2023 HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package tracing + +import ( + "context" + "iter" + "maps" + "runtime" + "sync" + "testing" +) + +// ContextProbe is a testing helper to allow tests to check whether +// [context.Context] values are being propagated correctly to various downstream +// functions where context value continuity is important for certain +// functionality, like tracing. (It's in this package because tracing is our +// primary motivation, but could potentially be used for other +// context-value-related situations too.) +// +// To use it, first call [NewContextProbe] from the test that wants to verify +// propagation, which returns both a [ContextProbe] and a [context.Context] +// that carries a value referring to it. Then in the function whose +// functionality requires context values to reach it, call [ContextProbeReport] +// with that function's own local context to notify any active context probe +// that the function was called. Finally, at the end of the test call +// [ContextProbe.ExpectReportsFrom] with all of the functions that the test +// expects should have been able to successfully call [ContextProbeReport]. +type ContextProbe struct { + calls map[string]struct{} + mu sync.Mutex +} + +type contextProbeKeyType int + +const contextProbeKey = contextProbeKeyType(0) + +// NewContextProbe creates a new [ContextProbe] and a new context (child of base) +// that is bound to it, so that [ContextProbeReport] with that context would +// record the call in the probe. +func NewContextProbe(t testing.TB, base context.Context) (context.Context, *ContextProbe) { + if existing := base.Value(contextProbeKey); existing != nil { + // We can only have one at a time so this is likely to be a programming + // error in the calling test, and so we'll report it explicitly rather + // than just quietly doing something confusing. + t.Fatal("base context already has a ContextProbe") + } + probe := &ContextProbe{ + calls: make(map[string]struct{}), + } + ctx := context.WithValue(base, contextProbeKey, probe) + return ctx, probe +} + +func (p *ContextProbe) report(f *runtime.Func) { + p.mu.Lock() + p.calls[f.Name()] = struct{}{} + p.mu.Unlock() +} + +// ExpectReportsFrom generates test errors (but does not terminate the test) +// if any of the given function names have not yet been reported by a +// call to [ContextProbeReport]. +// +// Returns true if no errors were generated, or false if at least one error +// was generated. +func (p *ContextProbe) ExpectReportsFrom(t testing.TB, names ...string) bool { + ret := true + for _, name := range names { + if _, called := p.calls[name]; !called { + t.Error("tracing.ContextProbeReport was not called by " + name) + ret = false + } + } + return ret +} + +// FunctionsReported returns an interable sequence of all of the functions +// that have called [ContextProbeReport] so far, in no particular order. +// +// Most tests should prefer to use [ContextProbe.ExpectReportsFrom] so that +// they don't get broken by reports intended for use by other tests, but +// this can be useful as a temporary addition to a test for debugging purposes, +// or to find out how the Go runtime describes a particular function of +// interest. +func (p *ContextProbe) FunctionsReported() iter.Seq[string] { + return maps.Keys(p.calls) +} + +// ContextProbeReport notifies the [ContextProbe] in the given context, if any, +// that its caller has been called. +// +// skipFrames is the number of callers to skip when deciding the name of the +// caller. Zero means to record the direct caller of ContextProbeReport. +// +// When called with a context that does not have a [ContextProbe] this does +// only the minimum work required to determine that there is no probe and +// immediately returns. The overhead is small, but there is still some overhead +// and so this function should not be called from functions used in tight loops +// but is okay to leave in normal codepaths otherwise. +func ContextProbeReport(ctx context.Context, skipFrames int) { + probe, ok := ctx.Value(contextProbeKey).(*ContextProbe) + if !ok { + return // fast return path for the no-probe case, to minimize overhead + } + + callerPc, _, _, ok := runtime.Caller(skipFrames + 1) + if !ok { + return + } + caller := runtime.FuncForPC(callerPc) + if caller == nil { + return + } + + probe.report(caller) +}