From a69d19d9f3a118aa2a6c4fc18b44783cf95681ba Mon Sep 17 00:00:00 2001 From: Christian Mesh Date: Thu, 18 Apr 2024 09:11:38 -0400 Subject: [PATCH] Allow configured providers to provide additional functions. (#1491) Signed-off-by: Christian Mesh --- go.mod | 2 + go.sum | 4 +- internal/addrs/parse_ref.go | 17 ++- internal/addrs/provider_function.go | 78 ++++++++++++ internal/builtin/providers/tf/provider.go | 4 + internal/lang/blocktoattr/fixup.go | 10 ++ internal/lang/blocktoattr/functions.go | 45 +++++++ internal/lang/eval.go | 50 ++++---- internal/lang/eval_test.go | 36 ++---- internal/lang/functions.go | 8 +- internal/lang/references.go | 23 +++- internal/lang/scope.go | 5 +- internal/legacy/tofu/provider_mock.go | 4 + internal/plugin/grpc_provider.go | 35 +++++- internal/plugin6/grpc_provider.go | 35 +++++- internal/provider-simple-v6/provider.go | 4 + internal/provider-simple/provider.go | 4 + internal/providers/provider.go | 12 +- internal/tfdiags/hcl_test.go | 4 + internal/tofu/context_functions.go | 139 +++++++++++++++------ internal/tofu/context_functions_test.go | 141 +++++++++++++--------- internal/tofu/context_plugins.go | 64 +--------- internal/tofu/eval_context_builtin.go | 37 ++---- internal/tofu/evaluate.go | 23 ++-- internal/tofu/graph_builder_apply.go | 6 + internal/tofu/graph_builder_eval.go | 6 + internal/tofu/graph_builder_plan.go | 6 + internal/tofu/node_root_variable_test.go | 4 + internal/tofu/provider_mock.go | 20 +++ internal/tofu/transform_provider.go | 132 +++++++++++++++++++- internal/tofu/transform_provider_test.go | 34 +++++- internal/tofu/transform_reference.go | 4 + 32 files changed, 720 insertions(+), 276 deletions(-) create mode 100644 internal/addrs/provider_function.go create mode 100644 internal/lang/blocktoattr/functions.go diff --git a/go.mod b/go.mod index a89079eacd..05668db27b 100644 --- a/go.mod +++ b/go.mod @@ -267,3 +267,5 @@ require ( ) go 1.21 + +replace github.com/hashicorp/hcl/v2 v2.20.1 => github.com/opentofu/hcl/v2 v2.0.0-20240416130056-03228b26f391 diff --git a/go.sum b/go.sum index db35cc8dcf..58dc97489e 100644 --- a/go.sum +++ b/go.sum @@ -716,8 +716,6 @@ github.com/hashicorp/golang-lru v0.5.1 h1:0hERBMJE1eitiLkihrMvRVBYAkpHzc/J3QdDN+ github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= -github.com/hashicorp/hcl/v2 v2.20.1 h1:M6hgdyz7HYt1UN9e61j+qKJBqR3orTWbI1HKBJEdxtc= -github.com/hashicorp/hcl/v2 v2.20.1/go.mod h1:TZDqQ4kNKCbh1iJp99FdPiUaVDDUPivbqxZulxDYqL4= github.com/hashicorp/jsonapi v0.0.0-20210826224640-ee7dae0fb22d h1:9ARUJJ1VVynB176G1HCwleORqCaXm/Vx0uUi0dL26I0= github.com/hashicorp/jsonapi v0.0.0-20210826224640-ee7dae0fb22d/go.mod h1:Yog5+CPEM3c99L1CL2CFCYoSzgWm5vTU58idbRUaLik= github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= @@ -927,6 +925,8 @@ github.com/onsi/gomega v1.10.1 h1:o0+MgICZLuZ7xjH7Vx6zS/zcu93/BEp1VwkIW1mEXCE= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/openbao/openbao/api v0.0.0-20240326035453-c075f0ef2c7e h1:LIQFfqW6BA5E2ycx8NNDgyKh0exFubHePM5pF3knogo= github.com/openbao/openbao/api v0.0.0-20240326035453-c075f0ef2c7e/go.mod h1:NUvBdXCNlmAGQ9TbYV7vS1Y9awHAjrq3QLiBWV+4Glk= +github.com/opentofu/hcl/v2 v2.0.0-20240416130056-03228b26f391 h1:Z2YGMhYBvmXBZlQdnlembuV4sp0lPJphIfgM9fVSjpU= +github.com/opentofu/hcl/v2 v2.0.0-20240416130056-03228b26f391/go.mod h1:TZDqQ4kNKCbh1iJp99FdPiUaVDDUPivbqxZulxDYqL4= github.com/opentofu/registry-address v0.0.0-20230920144404-f1e51167f633 h1:81TBkM/XGIFlVvyabp0CJl00UHeVUiQjz0fddLMi848= github.com/opentofu/registry-address v0.0.0-20230920144404-f1e51167f633/go.mod h1:HzQhpVo/NJnGmN+7FPECCVCA5ijU7AUcvf39enBKYOc= github.com/packer-community/winrmcp v0.0.0-20180921211025-c76d91c1e7db h1:9uViuKtx1jrlXLBW/pMnhOfzn3iSEdLase/But/IZRU= diff --git a/internal/addrs/parse_ref.go b/internal/addrs/parse_ref.go index e0dace786b..a448c08162 100644 --- a/internal/addrs/parse_ref.go +++ b/internal/addrs/parse_ref.go @@ -340,7 +340,6 @@ func parseRef(traversal hcl.Traversal) (*Reference, tfdiags.Diagnostics) { SourceRange: tfdiags.SourceRangeFromHCL(rng), Remaining: remain, }, diags - case "template", "lazy", "arg": // These names are all pre-emptively reserved in the hope of landing // some version of "template values" or "lazy expressions" feature @@ -354,6 +353,22 @@ func parseRef(traversal hcl.Traversal) (*Reference, tfdiags.Diagnostics) { return nil, diags default: + function := ParseFunction(root) + if function.IsNamespace(FunctionNamespaceProvider) { + pf, err := function.AsProviderFunction() + if err != nil { + return nil, diags.Append(&hcl.Diagnostic{ + Severity: hcl.DiagError, + Summary: "Unable to parse provider function", + Detail: err.Error(), + Subject: rootRange.Ptr(), + }) + } + return &Reference{ + Subject: pf, + SourceRange: tfdiags.SourceRangeFromHCL(rootRange), + }, diags + } return parseResourceRef(ManagedResourceMode, rootRange, traversal) } } diff --git a/internal/addrs/provider_function.go b/internal/addrs/provider_function.go new file mode 100644 index 0000000000..83d4ac9f6c --- /dev/null +++ b/internal/addrs/provider_function.go @@ -0,0 +1,78 @@ +package addrs + +import ( + "fmt" + "strings" +) + +// ProviderFunction is the address of a provider defined function. +type ProviderFunction struct { + referenceable + ProviderName string + ProviderAlias string + Function string +} + +func (v ProviderFunction) String() string { + if v.ProviderAlias != "" { + return fmt.Sprintf("provider::%s::%s::%s", v.ProviderName, v.ProviderAlias, v.Function) + } + return fmt.Sprintf("provider::%s::%s", v.ProviderName, v.Function) +} + +func (v ProviderFunction) UniqueKey() UniqueKey { + return v // A ProviderFunction is its own UniqueKey +} + +func (v ProviderFunction) uniqueKeySigil() {} + +type Function struct { + Namespaces []string + Name string +} + +const ( + FunctionNamespaceProvider = "provider" + FunctionNamespaceCore = "core" +) + +var FunctionNamespaces = []string{ + FunctionNamespaceProvider, + FunctionNamespaceCore, +} + +func ParseFunction(input string) Function { + parts := strings.Split(input, "::") + return Function{ + Name: parts[len(parts)-1], + Namespaces: parts[:len(parts)-1], + } +} + +func (f Function) String() string { + return strings.Join(append(f.Namespaces, f.Name), "::") +} + +func (f Function) IsNamespace(namespace string) bool { + return len(f.Namespaces) > 0 && f.Namespaces[0] == namespace +} + +func (f Function) AsProviderFunction() (pf ProviderFunction, err error) { + if !f.IsNamespace(FunctionNamespaceProvider) { + // Should always be checked ahead of time! + panic("BUG: non-provider function " + f.String()) + } + + if len(f.Namespaces) == 2 { + // provider:::: + pf.ProviderName = f.Namespaces[1] + } else if len(f.Namespaces) == 3 { + // provider:::::: + pf.ProviderName = f.Namespaces[1] + pf.ProviderAlias = f.Namespaces[2] + } else { + return pf, fmt.Errorf("invalid provider function %q: expected provider:::: or provider::::::", f) + } + pf.Function = f.Name + return pf, nil +} diff --git a/internal/builtin/providers/tf/provider.go b/internal/builtin/providers/tf/provider.go index b8a581a71f..87c519f68c 100644 --- a/internal/builtin/providers/tf/provider.go +++ b/internal/builtin/providers/tf/provider.go @@ -160,6 +160,10 @@ func (p *Provider) ValidateResourceConfig(req providers.ValidateResourceConfigRe return validateDataStoreResourceConfig(req) } +func (p *Provider) GetFunctions() providers.GetFunctionsResponse { + panic("unimplemented - terraform provider has no functions") +} + func (p *Provider) CallFunction(r providers.CallFunctionRequest) providers.CallFunctionResponse { panic("unimplemented - terraform provider has no functions") } diff --git a/internal/lang/blocktoattr/fixup.go b/internal/lang/blocktoattr/fixup.go index 7f6b17b86d..69678fbfaf 100644 --- a/internal/lang/blocktoattr/fixup.go +++ b/internal/lang/blocktoattr/fixup.go @@ -249,6 +249,16 @@ func (e *fixupBlocksExpr) Variables() []hcl.Traversal { return ret } +func (e *fixupBlocksExpr) Functions() []hcl.Traversal { + var ret []hcl.Traversal + schema := SchemaForCtyElementType(e.ety) + spec := schema.DecoderSpec() + for _, block := range e.blocks { + ret = append(ret, hcldec.Functions(block.Body, spec)...) + } + return ret +} + func (e *fixupBlocksExpr) Range() hcl.Range { // This is not really an appropriate range for the expression but it's // the best we can do from here. diff --git a/internal/lang/blocktoattr/functions.go b/internal/lang/blocktoattr/functions.go new file mode 100644 index 0000000000..8781c324e5 --- /dev/null +++ b/internal/lang/blocktoattr/functions.go @@ -0,0 +1,45 @@ +package blocktoattr + +import ( + "github.com/hashicorp/hcl/v2" + "github.com/hashicorp/hcl/v2/ext/dynblock" + "github.com/hashicorp/hcl/v2/hcldec" + "github.com/opentofu/opentofu/internal/configs/configschema" +) + +// ExpandedFunctions finds all of the global functions referenced in the +// given body with the given schema while taking into account the possibilities +// both of "dynamic" blocks being expanded and the possibility of certain +// attributes being written instead as nested blocks as allowed by the +// FixUpBlockAttrs function. +// +// This function exists to allow functions to be analyzed prior to dynamic +// block expansion while also dealing with the fact that dynamic block expansion +// might in turn produce nested blocks that are subject to FixUpBlockAttrs. +// +// This is intended as a drop-in replacement for dynblock.FunctionsHCLDec, +// which is itself a drop-in replacement for hcldec.Functions. +func ExpandedFunctions(body hcl.Body, schema *configschema.Block) []hcl.Traversal { + rootNode := dynblock.WalkFunctions(body) + return walkFunctions(rootNode, body, schema) +} + +func walkFunctions(node dynblock.WalkFunctionsNode, body hcl.Body, schema *configschema.Block) []hcl.Traversal { + givenRawSchema := hcldec.ImpliedSchema(schema.DecoderSpec()) + ambiguousNames := ambiguousNames(schema) + effectiveRawSchema := effectiveSchema(givenRawSchema, body, ambiguousNames, false) + vars, children := node.Visit(effectiveRawSchema) + + for _, child := range children { + if blockS, exists := schema.BlockTypes[child.BlockTypeName]; exists { + vars = append(vars, walkFunctions(child.Node, child.Body(), &blockS.Block)...) + } else if attrS, exists := schema.Attributes[child.BlockTypeName]; exists && attrS.Type.IsCollectionType() && attrS.Type.ElementType().IsObjectType() { + // ☝️Check for collection type before element type, because if this is a mis-placed reference, + // a panic here will prevent other useful diags from being elevated to show the user what to fix + synthSchema := SchemaForCtyElementType(attrS.Type.ElementType()) + vars = append(vars, walkFunctions(child.Node, child.Body(), synthSchema)...) + } + } + + return vars +} diff --git a/internal/lang/eval.go b/internal/lang/eval.go index b093e506d2..e3e5f8fc4a 100644 --- a/internal/lang/eval.go +++ b/internal/lang/eval.go @@ -7,7 +7,6 @@ package lang import ( "fmt" - "regexp" "strings" "github.com/hashicorp/hcl/v2" @@ -199,9 +198,6 @@ func (s *Scope) EvalExpr(expr hcl.Expression, wantType cty.Type) (cty.Value, tfd return val, diags } -// Common provider function namespace form -var providerFuncNamespace = regexp.MustCompile("^([^:]*)::([^:]*)::$") - // Identify and enhance any function related dialogs produced by a hcl.EvalContext func (s *Scope) enhanceFunctionDiags(diags hcl.Diagnostics) hcl.Diagnostics { out := make(hcl.Diagnostics, len(diags)) @@ -213,7 +209,7 @@ func (s *Scope) enhanceFunctionDiags(diags hcl.Diagnostics) hcl.Diagnostics { // prefix::stuff:: fullNamespace := funcExtra.CalledFunctionNamespace() - if !strings.Contains(fullNamespace, "::") { + if len(fullNamespace) == 0 { // Not a namespaced function, no enhancements nessesary continue } @@ -224,32 +220,23 @@ func (s *Scope) enhanceFunctionDiags(diags hcl.Diagnostics) hcl.Diagnostics { // Update enhanced with additional details - if fullNamespace == CoreNamespace { + fn := addrs.ParseFunction(fullNamespace + funcName) + + if fn.IsNamespace(addrs.FunctionNamespaceCore) { // Error is in core namespace, mirror non-core equivalent enhanced.Summary = "Call to unknown function" - enhanced.Detail = fmt.Sprintf("There is no builtin (%s) function named %q.", CoreNamespace, funcName) - continue - } - - match := providerFuncNamespace.FindSubmatch([]byte(fullNamespace)) - if match == nil || string(match[1]) != "provider" { - // complete mismatch or invalid prefix - enhanced.Summary = "Invalid function format" - enhanced.Detail = fmt.Sprintf("Expected provider::::, instead found \"%s%s\"", fullNamespace, funcName) - continue - } - - providerName := string(match[2]) - addr, ok := s.ProviderNames[providerName] - if !ok { - // Provider not registered - enhanced.Summary = "Unknown function provider" - enhanced.Detail = fmt.Sprintf("Provider %q does not exist within the required_providers of this module", providerName) + enhanced.Detail = fmt.Sprintf("There is no builtin (%s::) function named %q.", addrs.FunctionNamespaceCore, funcName) + } else if fn.IsNamespace(addrs.FunctionNamespaceProvider) { + if _, err := fn.AsProviderFunction(); err != nil { + // complete mismatch or invalid prefix + enhanced.Summary = "Invalid function format" + enhanced.Detail = err.Error() + } } else { - // Func not in provider - enhanced.Summary = "Function not found in provider" - enhanced.Detail = fmt.Sprintf("Function %q was not registered by provider named %q of type %q", funcName, providerName, addr) + enhanced.Summary = "Unknown function namespace" + enhanced.Detail = fmt.Sprintf("Function %q does not exist within a valid namespace (%s)", fn, strings.Join(addrs.FunctionNamespaces, ",")) } + // Function / Provider not found handled by eval_context_builtin.go } } return out @@ -478,7 +465,16 @@ func (s *Scope) evalContext(refs []*addrs.Reference, selfAddr addrs.Referenceabl val, valDiags := normalizeRefValue(s.Data.GetCheckBlock(subj, rng)) diags = diags.Append(valDiags) outputValues[subj.Name] = val + case addrs.ProviderFunction: + // Inject function directly into context + if _, ok := ctx.Functions[subj.String()]; !ok { + fn, fnDiags := s.ProviderFunctions(subj, rng) + diags = diags.Append(fnDiags) + if !fnDiags.HasErrors() { + ctx.Functions[subj.String()] = *fn + } + } default: // Should never happen panic(fmt.Errorf("Scope.buildEvalContext cannot handle address type %T", rawSubj)) diff --git a/internal/lang/eval_test.go b/internal/lang/eval_test.go index 2b3d24bdc6..3b8ef7fca4 100644 --- a/internal/lang/eval_test.go +++ b/internal/lang/eval_test.go @@ -874,26 +874,14 @@ func Test_enhanceFunctionDiags(t *testing.T) { { "Invalid prefix", "attr = magic::missing_function(54)", + "Unknown function namespace", + "Function \"magic::missing_function\" does not exist within a valid namespace (provider,core)", + }, + { + "Too many namespaces", + "attr = provider::foo::bar::extra::extra2::missing_function(54)", "Invalid function format", - "Expected provider::::, instead found \"magic::missing_function\"", - }, - { - "Broken prefix", - "attr = magic::foo::bar::extra::missing_function(54)", - "Invalid function format", - "Expected provider::::, instead found \"magic::foo::bar::extra::missing_function\"", - }, - { - "Missing provider", - "attr = provider::unknown::func(54)", - "Unknown function provider", - "Provider \"unknown\" does not exist within the required_providers of this module", - }, - { - "Missing function", - "attr = provider::known::func(54)", - "Function not found in provider", - "Function \"func\" was not registered by provider named \"known\" of type \"hostname/namespace/type\"", + "invalid provider function \"provider::foo::bar::extra::extra2::missing_function\": expected provider:::: or provider::::::", }, } @@ -919,15 +907,7 @@ func Test_enhanceFunctionDiags(t *testing.T) { body := file.Body - scope := &Scope{ - ProviderNames: map[string]addrs.Provider{ - "known": addrs.Provider{ - Type: "type", - Namespace: "namespace", - Hostname: "hostname", - }, - }, - } + scope := &Scope{} ctx, ctxDiags := scope.EvalContext(nil) if ctxDiags.HasErrors() { diff --git a/internal/lang/functions.go b/internal/lang/functions.go index 72f666ba80..73e05bf84a 100644 --- a/internal/lang/functions.go +++ b/internal/lang/functions.go @@ -14,6 +14,7 @@ import ( "github.com/zclconf/go-cty/cty/function" "github.com/zclconf/go-cty/cty/function/stdlib" + "github.com/opentofu/opentofu/internal/addrs" "github.com/opentofu/opentofu/internal/experiments" "github.com/opentofu/opentofu/internal/lang/funcs" ) @@ -24,7 +25,8 @@ var impureFunctions = []string{ "uuid", } -const CoreNamespace = "core::" +// This should probably be replaced with addrs.Function everywhere +const CoreNamespace = addrs.FunctionNamespaceCore + "::" // Functions returns the set of functions that should be used to when evaluating // expressions in the receiving scope. @@ -204,10 +206,6 @@ func (s *Scope) Functions() map[string]function.Function { for _, name := range coreNames { s.funcs[CoreNamespace+name] = s.funcs[name] } - - for name, f := range s.ProviderFunctions { - s.funcs[name] = f - } } s.funcsLock.Unlock() diff --git a/internal/lang/references.go b/internal/lang/references.go index 99ddfd5b48..6a43f8acde 100644 --- a/internal/lang/references.go +++ b/internal/lang/references.go @@ -72,7 +72,9 @@ func ReferencesInBlock(parseRef ParseRef, body hcl.Body, schema *configschema.Bl // in a better position to test this due to having mock providers etc // available. traversals := blocktoattr.ExpandedVariables(body, schema) - return References(parseRef, traversals) + funcs := filterProviderFunctions(blocktoattr.ExpandedFunctions(body, schema)) + + return References(parseRef, append(traversals, funcs...)) } // ReferencesInExpr is a helper wrapper around References that first searches @@ -83,5 +85,24 @@ func ReferencesInExpr(parseRef ParseRef, expr hcl.Expression) ([]*addrs.Referenc return nil, nil } traversals := expr.Variables() + if fexpr, ok := expr.(hcl.ExpressionWithFunctions); ok { + funcs := filterProviderFunctions(fexpr.Functions()) + traversals = append(traversals, funcs...) + } return References(parseRef, traversals) } + +func filterProviderFunctions(funcs []hcl.Traversal) []hcl.Traversal { + pfuncs := make([]hcl.Traversal, 0, len(funcs)) + for _, fn := range funcs { + if len(fn) == 0 { + continue + } + if root, ok := fn[0].(hcl.TraverseRoot); ok { + if addrs.ParseFunction(root.Name).IsNamespace(addrs.FunctionNamespaceProvider) { + pfuncs = append(pfuncs, fn) + } + } + } + return pfuncs +} diff --git a/internal/lang/scope.go b/internal/lang/scope.go index f305a52cc0..60950c9d57 100644 --- a/internal/lang/scope.go +++ b/internal/lang/scope.go @@ -72,10 +72,11 @@ type Scope struct { // either have been generated during this operation or read from the plan. PlanTimestamp time.Time - ProviderFunctions map[string]function.Function - ProviderNames map[string]addrs.Provider + ProviderFunctions ProviderFunction } +type ProviderFunction func(addrs.ProviderFunction, tfdiags.SourceRange) (*function.Function, tfdiags.Diagnostics) + // SetActiveExperiments allows a caller to declare that a set of experiments // is active for the module that the receiving Scope belongs to, which might // then cause the scope to activate some additional experimental behaviors. diff --git a/internal/legacy/tofu/provider_mock.go b/internal/legacy/tofu/provider_mock.go index b8c773fdf9..0e085e8173 100644 --- a/internal/legacy/tofu/provider_mock.go +++ b/internal/legacy/tofu/provider_mock.go @@ -362,6 +362,10 @@ func (p *MockProvider) ReadDataSource(r providers.ReadDataSourceRequest) provide return p.ReadDataSourceResponse } +func (p *MockProvider) GetFunctions() providers.GetFunctionsResponse { + panic("Not Implemented") +} + func (p *MockProvider) CallFunction(r providers.CallFunctionRequest) providers.CallFunctionResponse { panic("Not Implemented") } diff --git a/internal/plugin/grpc_provider.go b/internal/plugin/grpc_provider.go index 739e2bdd7d..b2b231c825 100644 --- a/internal/plugin/grpc_provider.go +++ b/internal/plugin/grpc_provider.go @@ -693,6 +693,26 @@ func (p *GRPCProvider) ReadDataSource(r providers.ReadDataSourceRequest) (resp p return resp } +func (p *GRPCProvider) GetFunctions() (resp providers.GetFunctionsResponse) { + logger.Trace("GRPCProvider: GetFunctions") + + protoReq := &proto.GetFunctions_Request{} + + protoResp, err := p.client.GetFunctions(p.ctx, protoReq) + if err != nil { + resp.Diagnostics = resp.Diagnostics.Append(grpcErr(err)) + return resp + } + resp.Diagnostics = resp.Diagnostics.Append(convert.ProtoToDiagnostics(protoResp.Diagnostics)) + resp.Functions = make(map[string]providers.FunctionSpec) + + for name, fn := range protoResp.Functions { + resp.Functions[name] = convert.ProtoToFunctionSpec(fn) + } + + return resp +} + func (p *GRPCProvider) CallFunction(r providers.CallFunctionRequest) (resp providers.CallFunctionResponse) { logger.Trace("GRPCProvider: CallFunction") @@ -705,9 +725,18 @@ func (p *GRPCProvider) CallFunction(r providers.CallFunctionRequest) (resp provi spec, ok := schema.Functions[r.Name] if !ok { - // This should be unreachable - resp.Error = fmt.Errorf("invalid CallFunctionRequest: function %s not defined in provider schema", r.Name) - return resp + funcs := p.GetFunctions() + if funcs.Diagnostics.HasErrors() { + // This should be unreachable + resp.Error = funcs.Diagnostics.Err() + return resp + } + spec, ok = funcs.Functions[r.Name] + if !ok { + // This should be unreachable + resp.Error = fmt.Errorf("invalid CallFunctionRequest: function %s not defined in provider schema", r.Name) + return resp + } } protoReq := &proto.CallFunction_Request{ diff --git a/internal/plugin6/grpc_provider.go b/internal/plugin6/grpc_provider.go index 21a1eb4d6f..1352323e20 100644 --- a/internal/plugin6/grpc_provider.go +++ b/internal/plugin6/grpc_provider.go @@ -682,6 +682,26 @@ func (p *GRPCProvider) ReadDataSource(r providers.ReadDataSourceRequest) (resp p return resp } +func (p *GRPCProvider) GetFunctions() (resp providers.GetFunctionsResponse) { + logger.Trace("GRPCProvider6: GetFunctions") + + protoReq := &proto6.GetFunctions_Request{} + + protoResp, err := p.client.GetFunctions(p.ctx, protoReq) + if err != nil { + resp.Diagnostics = resp.Diagnostics.Append(grpcErr(err)) + return resp + } + resp.Diagnostics = resp.Diagnostics.Append(convert.ProtoToDiagnostics(protoResp.Diagnostics)) + resp.Functions = make(map[string]providers.FunctionSpec) + + for name, fn := range protoResp.Functions { + resp.Functions[name] = convert.ProtoToFunctionSpec(fn) + } + + return resp +} + func (p *GRPCProvider) CallFunction(r providers.CallFunctionRequest) (resp providers.CallFunctionResponse) { logger.Trace("GRPCProvider6: CallFunction") @@ -694,9 +714,18 @@ func (p *GRPCProvider) CallFunction(r providers.CallFunctionRequest) (resp provi spec, ok := schema.Functions[r.Name] if !ok { - // This should be unreachable - resp.Error = fmt.Errorf("invalid CallFunctionRequest: function %s not defined in provider schema", r.Name) - return resp + funcs := p.GetFunctions() + if funcs.Diagnostics.HasErrors() { + // This should be unreachable + resp.Error = funcs.Diagnostics.Err() + return resp + } + spec, ok = funcs.Functions[r.Name] + if !ok { + // This should be unreachable + resp.Error = fmt.Errorf("invalid CallFunctionRequest: function %s not defined in provider schema", r.Name) + return resp + } } protoReq := &proto6.CallFunction_Request{ diff --git a/internal/provider-simple-v6/provider.go b/internal/provider-simple-v6/provider.go index d0868a4b55..00e83b2a57 100644 --- a/internal/provider-simple-v6/provider.go +++ b/internal/provider-simple-v6/provider.go @@ -147,6 +147,10 @@ func (s simple) ReadDataSource(req providers.ReadDataSourceRequest) (resp provid return resp } +func (s simple) GetFunctions() providers.GetFunctionsResponse { + panic("Not Implemented") +} + func (s simple) CallFunction(r providers.CallFunctionRequest) providers.CallFunctionResponse { panic("Not Implemented") } diff --git a/internal/provider-simple/provider.go b/internal/provider-simple/provider.go index 9b7757a94f..b75ad07771 100644 --- a/internal/provider-simple/provider.go +++ b/internal/provider-simple/provider.go @@ -138,6 +138,10 @@ func (s simple) ReadDataSource(req providers.ReadDataSourceRequest) (resp provid return resp } +func (s simple) GetFunctions() providers.GetFunctionsResponse { + panic("Not Implemented") +} + func (s simple) CallFunction(r providers.CallFunctionRequest) providers.CallFunctionResponse { panic("Not Implemented") } diff --git a/internal/providers/provider.go b/internal/providers/provider.go index ff7c3156b7..7deb2d983b 100644 --- a/internal/providers/provider.go +++ b/internal/providers/provider.go @@ -77,9 +77,9 @@ type Interface interface { // ReadDataSource returns the data source's current state. ReadDataSource(ReadDataSourceRequest) ReadDataSourceResponse - // GetFunctions not yet implemented or used at this stage as it is not required. - // tofu queries a full set of provider schemas early on in the process which contain - // the required information. + // GetFunctions returns a full list of functions defined in this provider. It should be a super + // set of the functions returned in GetProviderSchema() + GetFunctions() GetFunctionsResponse // CallFunction requests that the given function is called and response returned. CallFunction(CallFunctionRequest) CallFunctionResponse @@ -475,6 +475,12 @@ type ReadDataSourceResponse struct { Diagnostics tfdiags.Diagnostics } +type GetFunctionsResponse struct { + Functions map[string]FunctionSpec + + Diagnostics tfdiags.Diagnostics +} + type CallFunctionRequest struct { Name string Arguments []cty.Value diff --git a/internal/tfdiags/hcl_test.go b/internal/tfdiags/hcl_test.go index 0f8ffe855c..a851cd9a7c 100644 --- a/internal/tfdiags/hcl_test.go +++ b/internal/tfdiags/hcl_test.go @@ -99,6 +99,10 @@ func (e *fakeHCLExpression) Variables() []hcl.Traversal { return nil } +func (e *fakeHCLExpression) Functions() []hcl.Traversal { + return nil +} + func (e *fakeHCLExpression) Value(ctx *hcl.EvalContext) (cty.Value, hcl.Diagnostics) { return cty.DynamicVal, nil } diff --git a/internal/tofu/context_functions.go b/internal/tofu/context_functions.go index 48e04f789a..78ca6f54c4 100644 --- a/internal/tofu/context_functions.go +++ b/internal/tofu/context_functions.go @@ -3,55 +3,131 @@ package tofu import ( "errors" "fmt" - "log" - "sync" + "github.com/hashicorp/hcl/v2" "github.com/opentofu/opentofu/internal/addrs" + "github.com/opentofu/opentofu/internal/configs" "github.com/opentofu/opentofu/internal/providers" + "github.com/opentofu/opentofu/internal/tfdiags" "github.com/zclconf/go-cty/cty" "github.com/zclconf/go-cty/cty/function" ) -// Lazily creates a single instance of a provider for repeated use. -// Concurrency safe -func lazyProviderInstance(addr addrs.Provider, factory providers.Factory) providers.Factory { - var provider providers.Interface - var providerLock sync.Mutex - var err error +// This builds a provider function using an EvalContext and some additional information +// This is split out of BuiltinEvalContext for testing +func evalContextProviderFunction(ctx EvalContext, mc *configs.Config, op walkOperation, pf addrs.ProviderFunction, rng tfdiags.SourceRange) (*function.Function, tfdiags.Diagnostics) { + var diags tfdiags.Diagnostics - return func() (providers.Interface, error) { - providerLock.Lock() - defer providerLock.Unlock() + pr, ok := mc.Module.ProviderRequirements.RequiredProviders[pf.ProviderName] + if !ok { + return nil, diags.Append(&hcl.Diagnostic{ + Severity: hcl.DiagError, + Summary: "Unknown function provider", + Detail: fmt.Sprintf("Provider %q does not exist within the required_providers of this module", pf.ProviderName), + Subject: rng.ToHCL().Ptr(), + }) + } + // Very similar to transform_provider.go + absPc := addrs.AbsProviderConfig{ + Provider: pr.Type, + Module: mc.Path, + Alias: pf.ProviderAlias, + } + + provider := ctx.Provider(absPc) + + if provider == nil { + // Configured provider (NodeApplyableProvider) not required via transform_provider.go. Instead we should use the unconfigured instance (NodeEvalableProvider) in the root. + + // Make sure the alias is valid + validAlias := pf.ProviderAlias == "" + if !validAlias { + for _, alias := range pr.Aliases { + if alias.Alias == pf.ProviderAlias { + validAlias = true + break + } + } + if !validAlias { + return nil, diags.Append(&hcl.Diagnostic{ + Severity: hcl.DiagError, + Summary: "Unknown function provider", + Detail: fmt.Sprintf("No provider instance %q with alias %q", pf.ProviderName, pf.ProviderAlias), + Subject: rng.ToHCL().Ptr(), + }) + } + } + + provider = ctx.Provider(addrs.AbsProviderConfig{Provider: pr.Type}) if provider == nil { - log.Printf("[TRACE] tofu.contextFunctions: Initializing function provider %q", addr) - provider, err = factory() + // This should not be possible + return nil, diags.Append(&hcl.Diagnostic{ + Severity: hcl.DiagError, + Summary: "BUG: Uninitialized function provider", + Detail: fmt.Sprintf("Provider %q has not yet been initialized", absPc.String()), + Subject: rng.ToHCL().Ptr(), + }) } - return provider, err } -} -// Loop through all functions specified and build a map of name -> function. -// All functions will use the same lazily initialized provider instance. -// This instance will run until the application is terminated. -func providerFunctions(addr addrs.Provider, funcSpecs map[string]providers.FunctionSpec, factory providers.Factory) map[string]function.Function { - lazy := lazyProviderInstance(addr, factory) - - functions := make(map[string]function.Function) - for name, spec := range funcSpecs { - log.Printf("[TRACE] tofu.contextFunctions: Registering function %q in provider type %q", name, addr) - if _, ok := functions[name]; ok { - panic(fmt.Sprintf("broken provider %q: multiple functions registered under name %q", addr, name)) + // First try to look up the function from provider schema + schema := provider.GetProviderSchema() + if schema.Diagnostics.HasErrors() { + return nil, schema.Diagnostics + } + spec, ok := schema.Functions[pf.Function] + if !ok { + // During the validate operation, providers are not configured and therefore won't provide + // a comprehensive GetFunctions list + // Validate is built around unknown values already, we can stub in a placeholder + if op == walkValidate { + // Configured provider functions are not available during validate + fn := function.New(&function.Spec{ + Description: "Validate Placeholder", + VarParam: &function.Parameter{ + Type: cty.DynamicPseudoType, + AllowNull: true, + AllowUnknown: true, + AllowDynamicType: true, + AllowMarked: false, + }, + Type: function.StaticReturnType(cty.DynamicPseudoType), + Impl: func(args []cty.Value, retType cty.Type) (cty.Value, error) { + return cty.UnknownVal(cty.DynamicPseudoType), nil + }, + }) + return &fn, nil + } + + // The provider may be configured and present additional functions via GetFunctions + specs := provider.GetFunctions() + if specs.Diagnostics.HasErrors() { + return nil, specs.Diagnostics + } + + // If the function isn't in the custom GetFunctions list, it must be undefined + spec, ok = specs.Functions[pf.Function] + if !ok { + return nil, diags.Append(&hcl.Diagnostic{ + Severity: hcl.DiagError, + Summary: "Function not found in provider", + Detail: fmt.Sprintf("Function %q was not registered by provider %q", pf.Function, absPc.String()), + Subject: rng.ToHCL().Ptr(), + }) } - functions[name] = providerFunction(name, spec, lazy) } - return functions + + fn := providerFunction(pf.Function, spec, provider) + + return &fn, nil + } // Turn a provider function spec into a cty callable function // This will use the instance factory to get a provider to support the // function call. -func providerFunction(name string, spec providers.FunctionSpec, instance providers.Factory) function.Function { +func providerFunction(name string, spec providers.FunctionSpec, provider providers.Interface) function.Function { params := make([]function.Parameter, len(spec.Parameters)) for i, param := range spec.Parameters { params[i] = providerFunctionParameter(param) @@ -64,11 +140,6 @@ func providerFunction(name string, spec providers.FunctionSpec, instance provide } impl := func(args []cty.Value, retType cty.Type) (cty.Value, error) { - provider, err := instance() - if err != nil { - // Incredibly unlikely - return cty.UnknownVal(retType), err - } resp := provider.CallFunction(providers.CallFunctionRequest{ Name: name, Arguments: args, diff --git a/internal/tofu/context_functions_test.go b/internal/tofu/context_functions_test.go index 8428d58a1e..113d97c250 100644 --- a/internal/tofu/context_functions_test.go +++ b/internal/tofu/context_functions_test.go @@ -1,16 +1,18 @@ package tofu import ( - "fmt" "strings" "testing" "github.com/hashicorp/hcl/v2" "github.com/hashicorp/hcl/v2/hclsyntax" "github.com/opentofu/opentofu/internal/addrs" + "github.com/opentofu/opentofu/internal/configs" "github.com/opentofu/opentofu/internal/lang/marks" "github.com/opentofu/opentofu/internal/providers" + "github.com/opentofu/opentofu/internal/tfdiags" "github.com/zclconf/go-cty/cty" + "github.com/zclconf/go-cty/cty/function" ) func TestFunctions(t *testing.T) { @@ -105,64 +107,93 @@ func TestFunctions(t *testing.T) { return resp } - // Initial call to getSchema - expectProviderInit := true - - mockFactory := func() (providers.Interface, error) { - if !expectProviderInit { - return nil, fmt.Errorf("Unexpected call to provider init!") - } - expectProviderInit = false - return mockProvider, nil + mockProvider.GetFunctionsFn = func() (resp providers.GetFunctionsResponse) { + resp.Functions = mockProvider.GetProviderSchemaResponse.Functions + return resp } addr := addrs.NewDefaultProvider("mock") - plugins := newContextPluginsForTest(map[addrs.Provider]providers.Factory{ - addr: mockFactory, - }, t) + rng := tfdiags.SourceRange{} + providerFunc := func(fn string) addrs.ProviderFunction { + pf, _ := addrs.ParseFunction(fn).AsProviderFunction() + return pf + } - t.Run("empty names map", func(t *testing.T) { - res := plugins.Functions(map[string]addrs.Provider{}) - if len(res.ProviderNames) != 0 { - t.Error("did not expect any names") - } - if len(res.Functions) != 0 { - t.Error("did not expect any functions") - } - }) + mockCtx := new(MockEvalContext) + cfg := &configs.Config{ + Module: &configs.Module{ + ProviderRequirements: &configs.RequiredProviders{ + RequiredProviders: map[string]*configs.RequiredProvider{ + "mockname": &configs.RequiredProvider{ + Name: "mock", + Type: addr, + }, + }, + }, + }, + } - t.Run("broken names map", func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("Expected panic due to broken configuration") - } - }() + // Provider missing + _, diags := evalContextProviderFunction(mockCtx, cfg, walkValidate, providerFunc("provider::invalid::unknown"), rng) + if !diags.HasErrors() { + t.Fatal("expected unknown function provider") + } + if diags.Err().Error() != `Unknown function provider: Provider "invalid" does not exist within the required_providers of this module` { + t.Fatal(diags.Err()) + } - res := plugins.Functions(map[string]addrs.Provider{ - "borky": addrs.NewDefaultProvider("my_borky"), - }) - if len(res.ProviderNames) != 0 { - t.Error("did not expect any names") - } - if len(res.Functions) != 0 { - t.Error("did not expect any functions") - } - }) + // Provider not initialized + _, diags = evalContextProviderFunction(mockCtx, cfg, walkValidate, providerFunc("provider::mockname::missing"), rng) + if !diags.HasErrors() { + t.Fatal("expected unknown function provider") + } + if diags.Err().Error() != `BUG: Uninitialized function provider: Provider "provider[\"registry.opentofu.org/hashicorp/mock\"]" has not yet been initialized` { + t.Fatal(diags.Err()) + } - res := plugins.Functions(map[string]addrs.Provider{ - "mockname": addr, - }) - if res.ProviderNames["mockname"] != addr { - t.Errorf("expected names %q, got %q", addr, res.ProviderNames["mockname"]) + // "initialize" provider + mockCtx.ProviderProvider = mockProvider + + // Function missing (validate) + mockProvider.GetFunctionsCalled = false + _, diags = evalContextProviderFunction(mockCtx, cfg, walkValidate, providerFunc("provider::mockname::missing"), rng) + if diags.HasErrors() { + t.Fatal(diags.Err()) + } + if mockProvider.GetFunctionsCalled { + t.Fatal("expected GetFunctions NOT to be called since it's not initialized") + } + + // Function missing (Non-validate) + mockProvider.GetFunctionsCalled = false + _, diags = evalContextProviderFunction(mockCtx, cfg, walkPlan, providerFunc("provider::mockname::missing"), rng) + if !diags.HasErrors() { + t.Fatal("expected unknown function") + } + if diags.Err().Error() != `Function not found in provider: Function "missing" was not registered by provider "provider[\"registry.opentofu.org/hashicorp/mock\"]"` { + t.Fatal(diags.Err()) + } + if !mockProvider.GetFunctionsCalled { + t.Fatal("expected GetFunctions to be called") } ctx := &hcl.EvalContext{ - Functions: res.Functions, + Functions: map[string]function.Function{}, Variables: map[string]cty.Value{ "unknown_value": cty.UnknownVal(cty.String), "sensitive_value": cty.StringVal("sensitive!").Mark(marks.Sensitive), }, } + + // Load functions into ctx + for _, fn := range []string{"echo", "concat", "coalesce", "unknown_param", "error_param"} { + pf := providerFunc("provider::mockname::" + fn) + impl, diags := evalContextProviderFunction(mockCtx, cfg, walkPlan, pf, rng) + if diags.HasErrors() { + t.Fatal(diags.Err()) + } + ctx.Functions[pf.String()] = *impl + } evaluate := func(exprStr string) (cty.Value, hcl.Diagnostics) { expr, diags := hclsyntax.ParseExpression([]byte(exprStr), "exprtest", hcl.InitialPos) if diags.HasErrors() { @@ -203,22 +234,14 @@ func TestFunctions(t *testing.T) { // Actually test the function implementation - // Do this a few times but only expect a single init() - expectProviderInit = true - for i := 0; i < 5; i++ { - t.Log("Checking valid argument") + t.Log("Checking valid argument") - val, diags = evaluate(`provider::mockname::echo("hello functions!")`) - if diags.HasErrors() { - t.Error(diags.Error()) - } - if !val.RawEquals(cty.StringVal("hello functions!")) { - t.Error(val.AsString()) - } - - if expectProviderInit { - t.Error("Expected provider init to have been called") - } + val, diags = evaluate(`provider::mockname::echo("hello functions!")`) + if diags.HasErrors() { + t.Error(diags.Error()) + } + if !val.RawEquals(cty.StringVal("hello functions!")) { + t.Error(val.AsString()) } t.Log("Checking sensitive argument") diff --git a/internal/tofu/context_plugins.go b/internal/tofu/context_plugins.go index 719d06d282..14dd54f6ec 100644 --- a/internal/tofu/context_plugins.go +++ b/internal/tofu/context_plugins.go @@ -13,7 +13,6 @@ import ( "github.com/opentofu/opentofu/internal/configs/configschema" "github.com/opentofu/opentofu/internal/providers" "github.com/opentofu/opentofu/internal/provisioners" - "github.com/zclconf/go-cty/cty/function" ) // contextPlugins represents a library of available plugins (providers and @@ -22,50 +21,14 @@ import ( // about the providers for performance reasons. type contextPlugins struct { providerFactories map[addrs.Provider]providers.Factory - providerFunctions map[addrs.Provider]map[string]function.Function provisionerFactories map[string]provisioners.Factory } func newContextPlugins(providerFactories map[addrs.Provider]providers.Factory, provisionerFactories map[string]provisioners.Factory) (*contextPlugins, error) { - ret := &contextPlugins{ + return &contextPlugins{ providerFactories: providerFactories, provisionerFactories: provisionerFactories, - } - - // This is a bit convoluted as we need to use the ProviderSchema function call below to - // validate and initialize the provider schemas. Long term the whole provider abstraction - // needs to be re-thought. - var err error - ret.providerFunctions, err = ret.buildProviderFunctions() - if err != nil { - return nil, err - } - return ret, nil -} - -// Loop through all of the providerFactories and build a map of addr -> functions -// As a side effect, this initialzes the schema cache if not already initialized, with the proper validation path. -func (cp *contextPlugins) buildProviderFunctions() (map[addrs.Provider]map[string]function.Function, error) { - funcs := make(map[addrs.Provider]map[string]function.Function) - - // Pull all functions out of given providers - for addr, factory := range cp.providerFactories { - addr := addr - factory := factory - - // Before functions, the provider schemas were already pre-loaded and cached. That initial caching - // has been moved here. When the provider abstraction layers are refactored, this could instead - // expose and use provider.GetFunctions instead of needing to load and cache the whole schema. - // However, at the time of writing there is no benefit to defer caching these schemas in code - // paths which build a tofu.Context. - schema, err := cp.ProviderSchema(addr) - if err != nil { - return nil, err - } - - funcs[addr] = providerFunctions(addr, schema.Functions, factory) - } - return funcs, nil + }, nil // TODO remove error from this function call! } func (cp *contextPlugins) HasProvider(addr addrs.Provider) bool { @@ -216,26 +179,3 @@ func (cp *contextPlugins) ProvisionerSchema(typ string) (*configschema.Block, er return resp.Provisioner, nil } - -type ProviderFunctions struct { - ProviderNames map[string]addrs.Provider - Functions map[string]function.Function -} - -// Functions provides a map of provider:::: for a given provider type. -// All providers of a given type use the same functions and provider instance and -// additional names do not incur any performance penalty. -func (cp *contextPlugins) Functions(names map[string]addrs.Provider) *ProviderFunctions { - providerFuncs := &ProviderFunctions{ - ProviderNames: names, - Functions: make(map[string]function.Function), - } - - for name, addr := range names { - funcs := cp.providerFunctions[addr] - for fn_name, fn := range funcs { - providerFuncs.Functions[fmt.Sprintf("provider::%s::%s", name, fn_name)] = fn - } - } - return providerFuncs -} diff --git a/internal/tofu/eval_context_builtin.go b/internal/tofu/eval_context_builtin.go index 8c976334f1..539a233b7e 100644 --- a/internal/tofu/eval_context_builtin.go +++ b/internal/tofu/eval_context_builtin.go @@ -13,6 +13,9 @@ import ( "github.com/hashicorp/hcl/v2" "github.com/hashicorp/hcl/v2/hclsyntax" + "github.com/zclconf/go-cty/cty" + "github.com/zclconf/go-cty/cty/function" + "github.com/opentofu/opentofu/internal/addrs" "github.com/opentofu/opentofu/internal/checks" "github.com/opentofu/opentofu/internal/configs/configschema" @@ -27,7 +30,6 @@ import ( "github.com/opentofu/opentofu/internal/states" "github.com/opentofu/opentofu/internal/tfdiags" "github.com/opentofu/opentofu/version" - "github.com/zclconf/go-cty/cty" ) // BuiltinEvalContext is an EvalContext implementation that is used by @@ -70,8 +72,6 @@ type BuiltinEvalContext struct { ProviderLock *sync.Mutex ProvisionerCache map[string]provisioners.Interface ProvisionerLock *sync.Mutex - FunctionCache *ProviderFunctions - FunctionLock sync.Mutex ChangesValue *plans.ChangesSync StateValue *states.SyncState ChecksValue *checks.State @@ -90,8 +90,6 @@ func (ctx *BuiltinEvalContext) WithPath(path addrs.ModuleInstance) EvalContext { newCtx := *ctx newCtx.pathSet = true newCtx.PathValue = path - newCtx.FunctionCache = nil - newCtx.FunctionLock = sync.Mutex{} return &newCtx } @@ -129,18 +127,16 @@ func (ctx *BuiltinEvalContext) Input() UIInput { } func (ctx *BuiltinEvalContext) InitProvider(addr addrs.AbsProviderConfig) (providers.Interface, error) { - // If we already initialized, it is an error - if p := ctx.Provider(addr); p != nil { - return nil, fmt.Errorf("%s is already initialized", addr) - } - - // Warning: make sure to acquire these locks AFTER the call to Provider - // above, since it also acquires locks. ctx.ProviderLock.Lock() defer ctx.ProviderLock.Unlock() key := addr.String() + // If we have already initialized, it is an error + if _, ok := ctx.ProviderCache[key]; ok { + return nil, fmt.Errorf("%s is already initialized", addr) + } + p, err := ctx.Plugins.NewProviderInstance(addr.Provider) if err != nil { return nil, err @@ -526,20 +522,9 @@ func (ctx *BuiltinEvalContext) EvaluationScope(self addrs.Referenceable, source return ctx.Evaluator.Scope(data, self, source, nil) } - ctx.FunctionLock.Lock() - defer ctx.FunctionLock.Unlock() - if ctx.FunctionCache == nil { - names := make(map[string]addrs.Provider) - - // Providers must exist within required_providers to register their functions - for name, provider := range mc.Module.ProviderRequirements.RequiredProviders { - // Functions are only registered under their name, not their type name - names[name] = provider.Type - } - - ctx.FunctionCache = ctx.Plugins.Functions(names) - } - scope := ctx.Evaluator.Scope(data, self, source, ctx.FunctionCache) + scope := ctx.Evaluator.Scope(data, self, source, func(pf addrs.ProviderFunction, rng tfdiags.SourceRange) (*function.Function, tfdiags.Diagnostics) { + return evalContextProviderFunction(ctx, mc, ctx.Evaluator.Operation, pf, rng) + }) scope.SetActiveExperiments(mc.Module.ActiveExperiments) return scope diff --git a/internal/tofu/evaluate.go b/internal/tofu/evaluate.go index 4ef183aff1..13718c0147 100644 --- a/internal/tofu/evaluate.go +++ b/internal/tofu/evaluate.go @@ -76,21 +76,16 @@ type Evaluator struct { // If the "self" argument is nil then the "self" object is not available // in evaluated expressions. Otherwise, it behaves as an alias for the given // address. -func (e *Evaluator) Scope(data lang.Data, self addrs.Referenceable, source addrs.Referenceable, functions *ProviderFunctions) *lang.Scope { - if functions == nil { - functions = new(ProviderFunctions) - } +func (e *Evaluator) Scope(data lang.Data, self addrs.Referenceable, source addrs.Referenceable, functions lang.ProviderFunction) *lang.Scope { return &lang.Scope{ - Data: data, - ParseRef: addrs.ParseRef, - SelfAddr: self, - SourceAddr: source, - PureOnly: e.Operation != walkApply && e.Operation != walkDestroy && e.Operation != walkEval, - BaseDir: ".", // Always current working directory for now. - PlanTimestamp: e.PlanTimestamp, - // Can't pass the object directly as it would cause an import loop - ProviderNames: functions.ProviderNames, - ProviderFunctions: functions.Functions, + Data: data, + ParseRef: addrs.ParseRef, + SelfAddr: self, + SourceAddr: source, + PureOnly: e.Operation != walkApply && e.Operation != walkDestroy && e.Operation != walkEval, + BaseDir: ".", // Always current working directory for now. + PlanTimestamp: e.PlanTimestamp, + ProviderFunctions: functions, } } diff --git a/internal/tofu/graph_builder_apply.go b/internal/tofu/graph_builder_apply.go index c720f56909..d24d2a07bf 100644 --- a/internal/tofu/graph_builder_apply.go +++ b/internal/tofu/graph_builder_apply.go @@ -146,6 +146,12 @@ func (b *ApplyGraphBuilder) Steps() []GraphTransformer { // analyze the configuration to find references. &AttachSchemaTransformer{Plugins: b.Plugins, Config: b.Config}, + // After schema transformer, we can add function references + &ProviderFunctionTransformer{Config: b.Config}, + + // Remove unused providers and proxies + &PruneProviderTransformer{}, + // Create expansion nodes for all of the module calls. This must // come after all other transformers that create nodes representing // objects that can belong to modules. diff --git a/internal/tofu/graph_builder_eval.go b/internal/tofu/graph_builder_eval.go index 7dcc11d780..3066baaa51 100644 --- a/internal/tofu/graph_builder_eval.go +++ b/internal/tofu/graph_builder_eval.go @@ -89,6 +89,12 @@ func (b *EvalGraphBuilder) Steps() []GraphTransformer { // analyze the configuration to find references. &AttachSchemaTransformer{Plugins: b.Plugins, Config: b.Config}, + // After schema transformer, we can add function references + &ProviderFunctionTransformer{Config: b.Config}, + + // Remove unused providers and proxies + &PruneProviderTransformer{}, + // Create expansion nodes for all of the module calls. This must // come after all other transformers that create nodes representing // objects that can belong to modules. diff --git a/internal/tofu/graph_builder_plan.go b/internal/tofu/graph_builder_plan.go index 905450780c..0434d0c5f9 100644 --- a/internal/tofu/graph_builder_plan.go +++ b/internal/tofu/graph_builder_plan.go @@ -199,6 +199,12 @@ func (b *PlanGraphBuilder) Steps() []GraphTransformer { // analyze the configuration to find references. &AttachSchemaTransformer{Plugins: b.Plugins, Config: b.Config}, + // After schema transformer, we can add function references + &ProviderFunctionTransformer{Config: b.Config}, + + // Remove unused providers and proxies + &PruneProviderTransformer{}, + // Create expansion nodes for all of the module calls. This must // come after all other transformers that create nodes representing // objects that can belong to modules. diff --git a/internal/tofu/node_root_variable_test.go b/internal/tofu/node_root_variable_test.go index abdfbe23ec..66990eb7fa 100644 --- a/internal/tofu/node_root_variable_test.go +++ b/internal/tofu/node_root_variable_test.go @@ -179,6 +179,10 @@ func (f fakeHCLExpressionFunc) Variables() []hcl.Traversal { return nil } +func (f fakeHCLExpressionFunc) Functions() []hcl.Traversal { + return nil +} + func (f fakeHCLExpressionFunc) Range() hcl.Range { return hcl.Range{ Filename: "fake", diff --git a/internal/tofu/provider_mock.go b/internal/tofu/provider_mock.go index 220fa998fa..5693a3b98d 100644 --- a/internal/tofu/provider_mock.go +++ b/internal/tofu/provider_mock.go @@ -87,6 +87,10 @@ type MockProvider struct { ReadDataSourceRequest providers.ReadDataSourceRequest ReadDataSourceFn func(providers.ReadDataSourceRequest) providers.ReadDataSourceResponse + GetFunctionsCalled bool + GetFunctionsResponse *providers.GetFunctionsResponse + GetFunctionsFn func() providers.GetFunctionsResponse + CallFunctionCalled bool CallFunctionResponse *providers.CallFunctionResponse CallFunctionRequest providers.CallFunctionRequest @@ -521,6 +525,22 @@ func (p *MockProvider) ReadDataSource(r providers.ReadDataSourceRequest) (resp p return resp } +func (p *MockProvider) GetFunctions() (resp providers.GetFunctionsResponse) { + p.Lock() + defer p.Unlock() + + p.GetFunctionsCalled = true + + if p.GetFunctionsFn != nil { + return p.GetFunctionsFn() + } + + if p.GetFunctionsResponse != nil { + resp = *p.GetFunctionsResponse + } + return resp +} + func (p *MockProvider) CallFunction(r providers.CallFunctionRequest) (resp providers.CallFunctionResponse) { p.Lock() defer p.Unlock() diff --git a/internal/tofu/transform_provider.go b/internal/tofu/transform_provider.go index db0669bd19..b36ab2f8ee 100644 --- a/internal/tofu/transform_provider.go +++ b/internal/tofu/transform_provider.go @@ -32,8 +32,11 @@ func transformProviders(concrete ConcreteProviderNodeFunc, config *configs.Confi &ProviderTransformer{ Config: config, }, + // The following comment shows what must be added to the transformer list after the schema transformer + // After schema transformer, we can add function references + // &ProviderFunctionTransformer{Config: config}, // Remove unused providers and proxies - &PruneProviderTransformer{}, + // &PruneProviderTransformer{}, ) } @@ -180,6 +183,7 @@ func (t *ProviderTransformer) Transform(g *Graph) error { } if target != nil { + // Providers with configuration will already exist within the graph and can be directly referenced log.Printf("[TRACE] ProviderTransformer: exact match for %s serving %s", p, dag.VertexName(v)) } @@ -245,6 +249,130 @@ func (t *ProviderTransformer) Transform(g *Graph) error { return diags.Err() } +// ProviderFunctionTransformer is a GraphTransformer that maps nodes which reference functions to providers +// within the graph. This will error if there are any provider functions that don't map to known providers. +type ProviderFunctionTransformer struct { + Config *configs.Config +} + +func (t *ProviderFunctionTransformer) Transform(g *Graph) error { + var diags tfdiags.Diagnostics + + if t.Config == nil { + // This is probably a test case, inherited from ProviderTransformer + log.Printf("[WARN] Skipping provider function transformer due to missing config") + return nil + } + + // Locate all providers in the graph + providers := providerVertexMap(g) + + type providerReference struct { + path string + name string + alias string + } + // LuT of provider reference -> provider vertex + providerReferences := make(map[providerReference]dag.Vertex) + + for _, v := range g.Vertices() { + // Provider function references + if nr, ok := v.(GraphNodeReferencer); ok && t.Config != nil { + for _, ref := range nr.References() { + if pf, ok := ref.Subject.(addrs.ProviderFunction); ok { + key := providerReference{ + path: nr.ModulePath().String(), + name: pf.ProviderName, + alias: pf.ProviderAlias, + } + + // We already know about this provider and can link directly + if provider, ok := providerReferences[key]; ok { + // Is it worth skipping if we have already connected this provider? + g.Connect(dag.BasicEdge(v, provider)) + continue + } + + // Find the config that this node belongs to + mc := t.Config.Descendent(nr.ModulePath()) + if mc == nil { + // I don't think this is possible + diags = diags.Append(&hcl.Diagnostic{ + Severity: hcl.DiagError, + Summary: "Unknown Descendent Module", + Detail: nr.ModulePath().String(), + Subject: ref.SourceRange.ToHCL().Ptr(), + }) + continue + } + + // Find the provider type from required_providers + pr, ok := mc.Module.ProviderRequirements.RequiredProviders[pf.ProviderName] + if !ok { + diags = diags.Append(&hcl.Diagnostic{ + Severity: hcl.DiagError, + Summary: "Unknown function provider", + Detail: fmt.Sprintf("Provider %q does not exist within the required_providers of this module", pf.ProviderName), + Subject: ref.SourceRange.ToHCL().Ptr(), + }) + continue + } + + // Build fully qualified provider address + absPc := addrs.AbsProviderConfig{ + Provider: pr.Type, + Module: nr.ModulePath(), + Alias: pf.ProviderAlias, + } + + log.Printf("[TRACE] ProviderFunctionTransformer: %s in %s is provided by %s", pf, dag.VertexName(v), absPc) + + // Lookup provider via full address + provider := providers[absPc.String()] + + if provider != nil { + // Providers with configuration will already exist within the graph and can be directly referenced + log.Printf("[TRACE] ProviderFunctionTransformer: exact match for %s serving %s", absPc, dag.VertexName(v)) + } else { + // If this provider doesn't need to be configured then we can just + // stub it out with an init-only provider node, which will just + // start up the provider and fetch its schema. + stubAddr := addrs.AbsProviderConfig{ + Module: addrs.RootModule, + Provider: absPc.Provider, + } + if provider, ok = providers[stubAddr.String()]; !ok { + stub := &NodeEvalableProvider{ + &NodeAbstractProvider{ + Addr: stubAddr, + }, + } + providers[stubAddr.String()] = stub + log.Printf("[TRACE] ProviderFunctionTransformer: creating init-only node for %s", stubAddr) + provider = stub + g.Add(provider) + } + } + + // see if this is a proxy provider pointing to another concrete config + if p, ok := provider.(*graphNodeProxyProvider); ok { + g.Remove(p) + provider = p.Target() + } + + log.Printf("[DEBUG] ProviderFunctionTransformer: %q (%T) needs %s", dag.VertexName(v), v, dag.VertexName(provider)) + g.Connect(dag.BasicEdge(v, provider)) + + // Save for future lookups + providerReferences[key] = provider + } + } + } + } + + return diags.Err() +} + // CloseProviderTransformer is a GraphTransformer that adds nodes to the // graph that will close open provider connections that aren't needed anymore. // A provider connection is not needed anymore once all depended resources @@ -279,6 +407,8 @@ func (t *CloseProviderTransformer) Transform(g *Graph) error { for _, s := range g.UpEdges(p) { if _, ok := s.(GraphNodeProviderConsumer); ok { g.Connect(dag.BasicEdge(closer, s)) + } else if _, ok := s.(GraphNodeReferencer); ok { + g.Connect(dag.BasicEdge(closer, s)) } } } diff --git a/internal/tofu/transform_provider_test.go b/internal/tofu/transform_provider_test.go index 3456ac81d6..28b63a8858 100644 --- a/internal/tofu/transform_provider_test.go +++ b/internal/tofu/transform_provider_test.go @@ -31,6 +31,30 @@ func testProviderTransformerGraph(t *testing.T, cfg *configs.Config) *Graph { return g } +// This variant exists purely for testing and can not currently include the ProviderFunctionTransformer +func testTransformProviders(concrete ConcreteProviderNodeFunc, config *configs.Config) GraphTransformer { + return GraphTransformMulti( + // Add providers from the config + &ProviderConfigTransformer{ + Config: config, + Concrete: concrete, + }, + // Add any remaining missing providers + &MissingProviderTransformer{ + Config: config, + Concrete: concrete, + }, + // Connect the providers + &ProviderTransformer{ + Config: config, + }, + // After schema transformer, we can add function references + // &ProviderFunctionTransformer{Config: config}, + // Remove unused providers and proxies + &PruneProviderTransformer{}, + ) +} + func TestProviderTransformer(t *testing.T) { mod := testModule(t, "transform-provider-basic") @@ -181,7 +205,7 @@ func TestMissingProviderTransformer_grandchildMissing(t *testing.T) { g := testProviderTransformerGraph(t, mod) { - transform := transformProviders(concrete, mod) + transform := testTransformProviders(concrete, mod) if err := transform.Transform(g); err != nil { t.Fatalf("err: %s", err) } @@ -246,7 +270,7 @@ func TestProviderConfigTransformer_parentProviders(t *testing.T) { g := testProviderTransformerGraph(t, mod) { - tf := transformProviders(concrete, mod) + tf := testTransformProviders(concrete, mod) if err := tf.Transform(g); err != nil { t.Fatalf("err: %s", err) } @@ -266,7 +290,7 @@ func TestProviderConfigTransformer_grandparentProviders(t *testing.T) { g := testProviderTransformerGraph(t, mod) { - tf := transformProviders(concrete, mod) + tf := testTransformProviders(concrete, mod) if err := tf.Transform(g); err != nil { t.Fatalf("err: %s", err) } @@ -300,7 +324,7 @@ resource "test_object" "a" { g := testProviderTransformerGraph(t, mod) { - tf := transformProviders(concrete, mod) + tf := testTransformProviders(concrete, mod) if err := tf.Transform(g); err != nil { t.Fatalf("err: %s", err) } @@ -378,7 +402,7 @@ resource "test_object" "a" { g := testProviderTransformerGraph(t, mod) { - tf := transformProviders(concrete, mod) + tf := testTransformProviders(concrete, mod) if err := tf.Transform(g); err != nil { t.Fatalf("err: %s", err) } diff --git a/internal/tofu/transform_reference.go b/internal/tofu/transform_reference.go index 715a9c65d7..7f09139372 100644 --- a/internal/tofu/transform_reference.go +++ b/internal/tofu/transform_reference.go @@ -340,6 +340,8 @@ func (m ReferenceMap) addReference(path addrs.Module, current dag.Vertex, ref *a subject = ri.ModuleCallOutput() case addrs.ModuleCallInstance: subject = ri.Call + case addrs.ProviderFunction: + return nil default: log.Printf("[INFO] ReferenceTransformer: reference not found: %q", subject) return nil @@ -433,6 +435,8 @@ func (m ReferenceMap) dataDependsOn(depender graphNodeDependsOn) []*addrs.Refere case addrs.ResourceInstance: resAddr = s.Resource r.Subject = resAddr + case addrs.ProviderFunction: + continue } if resAddr.Mode != addrs.ManagedResourceMode {