diff --git a/cmd/tofu/commands.go b/cmd/tofu/commands.go index 5e9f62de7b..e77fcc31e0 100644 --- a/cmd/tofu/commands.go +++ b/cmd/tofu/commands.go @@ -122,6 +122,8 @@ func initCommands( UnmanagedProviders: unmanagedProviders, AllowExperimentalFeatures: experimentsAreAllowed(), + + ProviderSourceLocationConfig: providerSourceLocationConfig(), } // The command list is included in the tofu -help @@ -362,9 +364,9 @@ func initCommands( }, nil }, - //----------------------------------------------------------- + // ----------------------------------------------------------- // Plumbing - //----------------------------------------------------------- + // ----------------------------------------------------------- "force-unlock": func() (cli.Command, error) { return &command.UnlockCommand{ diff --git a/cmd/tofu/provider_source.go b/cmd/tofu/provider_source.go index 36cdbf7f21..ef71968a08 100644 --- a/cmd/tofu/provider_source.go +++ b/cmd/tofu/provider_source.go @@ -12,6 +12,7 @@ import ( "net/url" "os" "path/filepath" + "strconv" "github.com/apparentlymart/go-userdirs/userdirs" "github.com/opentofu/svchost/disco" @@ -207,7 +208,7 @@ func implicitProviderSource( // local copy will take precedence. searchRules = append(searchRules, getproviders.MultiSourceSelector{ Source: getproviders.NewMemoizeSource( - getproviders.NewRegistrySource(ctx, services, newRegistryHTTPClient(ctx, registryClientConfig)), + getproviders.NewRegistrySource(ctx, services, newRegistryHTTPClient(ctx, registryClientConfig), providerSourceLocationConfig()), ), Exclude: directExcluded, }) @@ -224,7 +225,7 @@ func providerSourceForCLIConfigLocation( ) (getproviders.Source, tfdiags.Diagnostics) { if loc == cliconfig.ProviderInstallationDirect { return getproviders.NewMemoizeSource( - getproviders.NewRegistrySource(ctx, services, newRegistryHTTPClient(ctx, registryClientConfig)), + getproviders.NewRegistrySource(ctx, services, newRegistryHTTPClient(ctx, registryClientConfig), providerSourceLocationConfig()), ), nil } @@ -258,7 +259,7 @@ func providerSourceForCLIConfigLocation( // this client is not suitable for the HTTP mirror source, so we // don't use this client directly. httpTimeout := newRegistryHTTPClient(ctx, registryClientConfig).HTTPClient.Timeout - return getproviders.NewHTTPMirrorSource(ctx, url, services.CredentialsSource(), httpTimeout), nil + return getproviders.NewHTTPMirrorSource(ctx, url, services.CredentialsSource(), httpTimeout, providerSourceLocationConfig()), nil case cliconfig.ProviderInstallationOCIMirror: mappingFunc := loc.RepositoryMapping @@ -298,3 +299,34 @@ func providerDevOverrides(configs []*cliconfig.ProviderInstallation) map[addrs.P // ignore any additional configurations in here. return configs[0].DevOverrides } + +const ( + // providerDownloadRetryCountEnvName is the environment variable name used to customize + // the HTTP retry count for module downloads. + providerDownloadRetryCountEnvName = "TF_PROVIDER_DOWNLOAD_RETRY" + + providerDownloadDefaultRetry = 2 +) + +// providerDownloadRetry will attempt for requests with retryable errors, like 502 status codes +func providerDownloadRetry() int { + res := providerDownloadDefaultRetry + if v := os.Getenv(providerDownloadRetryCountEnvName); v != "" { + retry, err := strconv.Atoi(v) + if err == nil && retry > 0 { + res = retry + } + } + return res +} + +// providerSourceLocationConfig is meant to build a global configuration for the +// remote locations to download a provider from. This is built out of the +// TF_PROVIDER_DOWNLOAD_RETRY env variable and is meant to be passed through +// [getproviders.Source] all the way down to the [getproviders.PackageLocation] +// to be able to tweak the configurations of the http clients used there. +func providerSourceLocationConfig() getproviders.LocationConfig { + return getproviders.LocationConfig{ + ProviderDownloadRetries: providerDownloadRetry(), + } +} diff --git a/cmd/tofu/provider_source_test.go b/cmd/tofu/provider_source_test.go index d1f0192d35..e3dd1568bd 100644 --- a/cmd/tofu/provider_source_test.go +++ b/cmd/tofu/provider_source_test.go @@ -12,9 +12,11 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/opentofu/opentofu/internal/addrs" "github.com/opentofu/opentofu/internal/command/cliconfig" "github.com/opentofu/opentofu/internal/command/cliconfig/ociauthconfig" + "github.com/opentofu/opentofu/internal/getproviders" "github.com/opentofu/svchost/disco" ) @@ -157,3 +159,43 @@ func TestProviderSource(t *testing.T) { }) } } + +func TestConfigureProviderDownloadRetry(t *testing.T) { + tests := []struct { + name string + envVars map[string]string + + expectedConfig getproviders.LocationConfig + }{ + { + name: "when no TF_PROVIDER_DOWNLOAD_RETRY env var, default retry attempts used for provider download", + expectedConfig: getproviders.LocationConfig{ + ProviderDownloadRetries: providerDownloadDefaultRetry, + }, + }, + { + name: "when TF_PROVIDER_DOWNLOAD_RETRY env var configured, it is used provider download", + envVars: map[string]string{ + "TF_PROVIDER_DOWNLOAD_RETRY": "7", + }, + expectedConfig: getproviders.LocationConfig{ + ProviderDownloadRetries: 7, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for k, v := range tt.envVars { + t.Setenv(k, v) + } + + // Call the function under test + got := providerSourceLocationConfig() + + if diff := cmp.Diff(tt.expectedConfig, got); diff != "" { + t.Fatalf("expected no diff. got:\n%s", diff) + } + }) + } +} diff --git a/internal/command/command_test.go b/internal/command/command_test.go index 0c8418a56d..7f932d2f81 100644 --- a/internal/command/command_test.go +++ b/internal/command/command_test.go @@ -1029,7 +1029,7 @@ func testServices(t *testing.T) (services *disco.Disco, cleanup func()) { // of your test in order to shut down the test server. func testRegistrySource(t *testing.T) (source *getproviders.RegistrySource, cleanup func()) { services, close := testServices(t) - source = getproviders.NewRegistrySource(t.Context(), services, nil) + source = getproviders.NewRegistrySource(t.Context(), services, nil, getproviders.LocationConfig{ProviderDownloadRetries: 0}) return source, close } diff --git a/internal/command/meta.go b/internal/command/meta.go index 399dd7757e..63f488adbb 100644 --- a/internal/command/meta.go +++ b/internal/command/meta.go @@ -200,9 +200,9 @@ type Meta struct { // flag is set, to reinforce that experiments are not for production use. AllowExperimentalFeatures bool - //---------------------------------------------------------- + // ---------------------------------------------------------- // Protected: commands can set these - //---------------------------------------------------------- + // ---------------------------------------------------------- // pluginPath is a user defined set of directories to look for plugins. // This is set during init with the `-plugin-dir` flag, saved to a file in @@ -213,9 +213,9 @@ type Meta struct { // Override certain behavior for tests within this package testingOverrides *testingOverrides - //---------------------------------------------------------- + // ---------------------------------------------------------- // Private: do not set these - //---------------------------------------------------------- + // ---------------------------------------------------------- // configLoader is a shared configuration loader that is used by // LoadConfig and other commands that access configuration files. @@ -305,6 +305,13 @@ type Meta struct { // This helps prevent duplicate errors/warnings. rootModuleCallCache *configs.StaticModuleCall inputVariableCache map[string]backend.UnparsedVariableValue + + // Since `tofu providers lock` and `tofu providers mirror` have their own + // logic to create the source to fetch providers through, we had to + // plumb this configuration through the [Meta] type to reach that part too. + // In any other cases, this configuration is built and used directly in `realMain` + // when the providers sources are built. + ProviderSourceLocationConfig getproviders.LocationConfig } type testingOverrides struct { diff --git a/internal/command/providers_lock.go b/internal/command/providers_lock.go index 083bc8390a..eb3036d915 100644 --- a/internal/command/providers_lock.go +++ b/internal/command/providers_lock.go @@ -139,12 +139,12 @@ func (c *ProvidersLockCommand) Run(args []string) int { // this client is not suitable for the HTTP mirror source, so we // don't use this client directly. httpTimeout := c.registryHTTPClient(ctx).HTTPClient.Timeout - source = getproviders.NewHTTPMirrorSource(ctx, u, c.Services.CredentialsSource(), httpTimeout) + source = getproviders.NewHTTPMirrorSource(ctx, u, c.Services.CredentialsSource(), httpTimeout, c.ProviderSourceLocationConfig) default: // With no special options we consult upstream registries directly, // because that gives us the most information to produce as complete // and portable as possible a lock entry. - source = getproviders.NewRegistrySource(ctx, c.Services, c.registryHTTPClient(ctx)) + source = getproviders.NewRegistrySource(ctx, c.Services, c.registryHTTPClient(ctx), c.ProviderSourceLocationConfig) } config, confDiags := c.loadConfig(ctx, ".") diff --git a/internal/command/providers_mirror.go b/internal/command/providers_mirror.go index 000b42158a..167e055829 100644 --- a/internal/command/providers_mirror.go +++ b/internal/command/providers_mirror.go @@ -112,7 +112,7 @@ func (c *ProvidersMirrorCommand) Run(args []string) int { // directory without needing to first disable that local mirror // in the CLI configuration. source := getproviders.NewMemoizeSource( - getproviders.NewRegistrySource(ctx, c.Services, c.registryHTTPClient(ctx)), + getproviders.NewRegistrySource(ctx, c.Services, c.registryHTTPClient(ctx), c.ProviderSourceLocationConfig), ) // Providers from registries always use HTTP, so we don't need the full @@ -190,7 +190,7 @@ func (c *ProvidersMirrorCommand) Run(args []string) int { )) continue } - urlStr, ok := meta.Location.(getproviders.PackageHTTPURL) + httpPkg, ok := meta.Location.(getproviders.PackageHTTPURL) if !ok { // We don't expect to get non-HTTP locations here because we're // using the registry source, so this seems like a bug in the @@ -202,7 +202,7 @@ func (c *ProvidersMirrorCommand) Run(args []string) int { )) continue } - urlObj, err := url.Parse(string(urlStr)) + urlObj, err := url.Parse(httpPkg.URL) if err != nil { // We don't expect to get non-HTTP locations here because we're // using the registry source, so this seems like a bug in the diff --git a/internal/getproviders/http_mirror_source.go b/internal/getproviders/http_mirror_source.go index 71500eeef4..1b41e11db9 100644 --- a/internal/getproviders/http_mirror_source.go +++ b/internal/getproviders/http_mirror_source.go @@ -31,9 +31,10 @@ import ( // HTTPMirrorSource is a source that reads provider metadata from a provider // mirror that is accessible over the HTTP provider mirror protocol. type HTTPMirrorSource struct { - baseURL *url.URL - creds svcauth.CredentialsSource - httpClient *retryablehttp.Client + baseURL *url.URL + creds svcauth.CredentialsSource + httpClient *retryablehttp.Client + locationConfig LocationConfig } var _ Source = (*HTTPMirrorSource)(nil) @@ -46,7 +47,7 @@ var _ Source = (*HTTPMirrorSource)(nil) // (When the URL comes from user input, such as in the CLI config, it's the // UI/config layer's responsibility to validate this and return a suitable // error message for the end-user audience.) -func NewHTTPMirrorSource(ctx context.Context, baseURL *url.URL, creds svcauth.CredentialsSource, requestTimeout time.Duration) *HTTPMirrorSource { +func NewHTTPMirrorSource(ctx context.Context, baseURL *url.URL, creds svcauth.CredentialsSource, requestTimeout time.Duration, sourceLocationCfg LocationConfig) *HTTPMirrorSource { httpClient := httpclient.NewForRegistryRequests(ctx, 0, requestTimeout) httpClient.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { // If we get redirected more than five times we'll assume we're @@ -56,17 +57,18 @@ func NewHTTPMirrorSource(ctx context.Context, baseURL *url.URL, creds svcauth.Cr } return nil } - return newHTTPMirrorSourceWithHTTPClient(baseURL, creds, httpClient) + return newHTTPMirrorSourceWithHTTPClient(baseURL, creds, httpClient, sourceLocationCfg) } -func newHTTPMirrorSourceWithHTTPClient(baseURL *url.URL, creds svcauth.CredentialsSource, httpClient *retryablehttp.Client) *HTTPMirrorSource { +func newHTTPMirrorSourceWithHTTPClient(baseURL *url.URL, creds svcauth.CredentialsSource, httpClient *retryablehttp.Client, sourceLocationCfg LocationConfig) *HTTPMirrorSource { if baseURL.Scheme != "https" { panic("non-https URL for HTTP mirror") } return &HTTPMirrorSource{ - baseURL: baseURL, - creds: creds, - httpClient: httpClient, + baseURL: baseURL, + creds: creds, + httpClient: httpClient, + locationConfig: sourceLocationCfg, } } @@ -209,7 +211,9 @@ func (s *HTTPMirrorSource) PackageMeta(ctx context.Context, provider addrs.Provi Version: version, TargetPlatform: target, - Location: PackageHTTPURL(absURL.String()), + Location: PackageHTTPURL{URL: absURL.String(), ClientBuilder: func(ctx context.Context) *retryablehttp.Client { + return packageHTTPUrlClientWithRetry(ctx, s.locationConfig.ProviderDownloadRetries) + }}, Filename: path.Base(absURL.Path), } // A network mirror might not provide any hashes at all, in which case diff --git a/internal/getproviders/http_mirror_source_test.go b/internal/getproviders/http_mirror_source_test.go index 0588dbe303..c745e2ef83 100644 --- a/internal/getproviders/http_mirror_source_test.go +++ b/internal/getproviders/http_mirror_source_test.go @@ -11,10 +11,13 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" + "github.com/apparentlymart/go-versions/versions" "github.com/google/go-cmp/cmp" "github.com/hashicorp/go-retryablehttp" + "github.com/opentofu/opentofu/internal/httpclient" "github.com/opentofu/svchost" "github.com/opentofu/svchost/svcauth" @@ -38,7 +41,7 @@ func TestHTTPMirrorSource(t *testing.T) { }) retryHTTPClient := retryablehttp.NewClient() retryHTTPClient.HTTPClient = httpClient - source := newHTTPMirrorSourceWithHTTPClient(baseURL, creds, retryHTTPClient) + source := newHTTPMirrorSourceWithHTTPClient(baseURL, creds, retryHTTPClient, LocationConfig{}) existingProvider := addrs.MustParseProviderSourceString("terraform.io/test/exists") missingProvider := addrs.MustParseProviderSourceString("terraform.io/test/missing") @@ -47,6 +50,22 @@ func TestHTTPMirrorSource(t *testing.T) { redirectLoopProvider := addrs.MustParseProviderSourceString("terraform.io/test/redirect-loop") tosPlatform := Platform{OS: "tos", Arch: "m68k"} + clientBuilderFromHTTPLocation := func(t *testing.T, expectedRetries int) func(ctx context.Context) *retryablehttp.Client { + return func(ctx context.Context) *retryablehttp.Client { + return packageHTTPUrlClientWithRetry(ctx, expectedRetries) + } + } + // For the PackageHTTPURL.ClientBuilder we are interested strictly in comparing the max retries and nothing else. + // Comparing the whole client gets into different issues so we keep this comparison simple. + cmpClientBuilder := cmp.Comparer(func(a, b func(ctx context.Context) *retryablehttp.Client) bool { + given := a(t.Context()) + got := b(t.Context()) + if got.RetryMax != given.RetryMax { + t.Logf("expected to have retry max as %d but got %d", got.RetryMax, given.RetryMax) + } + return true + }) + t.Run("AvailableVersions for provider that exists", func(t *testing.T) { got, _, err := source.AvailableVersions(context.Background(), existingProvider) if err != nil { @@ -73,7 +92,7 @@ func TestHTTPMirrorSource(t *testing.T) { } }) t.Run("AvailableVersions without required credentials", func(t *testing.T) { - unauthSource := newHTTPMirrorSourceWithHTTPClient(baseURL, nil, retryHTTPClient) + unauthSource := newHTTPMirrorSourceWithHTTPClient(baseURL, nil, retryHTTPClient, LocationConfig{}) _, _, err := unauthSource.AvailableVersions(context.Background(), existingProvider) switch err := err.(type) { case ErrUnauthorized: @@ -128,14 +147,14 @@ func TestHTTPMirrorSource(t *testing.T) { Version: version, TargetPlatform: tosPlatform, Filename: "terraform-provider-test_v1.0.0_tos_m68k.zip", - Location: PackageHTTPURL(httpServer.URL + "/terraform.io/test/exists/terraform-provider-test_v1.0.0_tos_m68k.zip"), + Location: PackageHTTPURL{URL: httpServer.URL + "/terraform.io/test/exists/terraform-provider-test_v1.0.0_tos_m68k.zip", ClientBuilder: clientBuilderFromHTTPLocation(t, retryHTTPClient.RetryMax)}, Authentication: packageHashAuthentication{ RequiredHashes: []Hash{"h1:placeholder-hash"}, AllHashes: []Hash{"h1:placeholder-hash", "h0:unacceptable-hash"}, Platform: Platform{"tos", "m68k"}, }, } - if diff := cmp.Diff(want, got); diff != "" { + if diff := cmp.Diff(want, got, cmpClientBuilder); diff != "" { t.Errorf("wrong result\n%s", diff) } }) @@ -151,10 +170,10 @@ func TestHTTPMirrorSource(t *testing.T) { Version: version, TargetPlatform: tosPlatform, Filename: "terraform-provider-test_v1.0.1_tos_m68k.zip", - Location: PackageHTTPURL(httpServer.URL + "/terraform.io/test/exists/terraform-provider-test_v1.0.1_tos_m68k.zip"), + Location: PackageHTTPURL{URL: httpServer.URL + "/terraform.io/test/exists/terraform-provider-test_v1.0.1_tos_m68k.zip", ClientBuilder: clientBuilderFromHTTPLocation(t, retryHTTPClient.RetryMax)}, Authentication: nil, } - if diff := cmp.Diff(want, got); diff != "" { + if diff := cmp.Diff(want, got, cmpClientBuilder); diff != "" { t.Errorf("wrong result\n%s", diff) } }) @@ -191,9 +210,9 @@ func TestHTTPMirrorSource(t *testing.T) { // NOTE: The final URL is interpreted relative to the redirect // target, not relative to what we originally requested. - Location: PackageHTTPURL(httpServer.URL + "/redirect-target/terraform-provider-test.zip"), + Location: PackageHTTPURL{URL: httpServer.URL + "/redirect-target/terraform-provider-test.zip", ClientBuilder: clientBuilderFromHTTPLocation(t, retryHTTPClient.RetryMax)}, } - if diff := cmp.Diff(want, got); diff != "" { + if diff := cmp.Diff(want, got, cmpClientBuilder); diff != "" { t.Errorf("wrong result\n%s", diff) } }) @@ -219,7 +238,10 @@ func testHTTPMirrorSourceHandler(resp http.ResponseWriter, req *http.Request) { resp.WriteHeader(401) fmt.Fprintln(resp, "incorrect auth token") } + testHTTPMirrorSourceHandlerNoAuth(resp, req) +} +func testHTTPMirrorSourceHandlerNoAuth(resp http.ResponseWriter, req *http.Request) { switch req.URL.Path { case "/terraform.io/test/exists/index.json": resp.Header().Add("Content-Type", "application/json; ignored=yes") @@ -317,8 +339,73 @@ func testHTTPMirrorSourceHandler(resp http.ResponseWriter, req *http.Request) { resp.WriteHeader(301) fmt.Fprint(resp, "redirect loop") + case "/terraform.io/missing/providerbinary/1.2.0.json": + resp.Header().Add("Content-Type", "application/json; ignored=yes") + resp.WriteHeader(200) + fmt.Fprint(resp, ` + { + "archives": { + "tos_m68k": { + "url": "terraform-missing-providerbinary_v1.2.0_tos_m68k.zip", + "hashes": [ + "h1:placeholder-hash", + "h0:unacceptable-hash" + ] + } + } + } + `) + case "/terraform.io/missing/providerbinary/terraform-missing-providerbinary_v1.2.0_tos_m68k.zip": + resp.Header().Add("Content-Type", "application/json; ignored=yes") + // Just return a retryable status code + resp.WriteHeader(http.StatusInternalServerError) + default: resp.WriteHeader(404) fmt.Fprintln(resp, "not found") } } + +// Checks that the [LocationConfig] is used properly to configure the [PackageHTTPURL] http client, +// meaning that the retries are configured as expected. +func TestHTTPMirrorLocationRetriesConfiguredCorrectly(t *testing.T) { + // Using the NoAuth handler and NoTLSServer here because the http client + // used to configure the [PackageHTTPURL] does not have the credentials and the + // CA forwarded from the HTTPMirrorSource. + httpServer := httptest.NewServer(http.HandlerFunc(testHTTPMirrorSourceHandlerNoAuth)) + defer httpServer.Close() + baseURL, err := url.Parse(httpServer.URL) + if err != nil { + t.Fatalf("unexpected error parsing the url: %s", err) + } + creds := svcauth.StaticCredentialsSource(map[svchost.Hostname]svcauth.HostCredentials{ + svchost.Hostname(baseURL.Host): svcauth.HostCredentialsToken("placeholder-token"), + }) + retryHTTPClient := retryablehttp.NewClient() + retryHTTPClient.HTTPClient = httpclient.New(t.Context()) + // The same reason as above applies on why using here directly the struct for initialisation + // instead of calling [newHTTPMirrorSourceWithHTTPClient]. We have no way to forward the CA + // from the http client got from the TLS server into the [PackageHTTPURL] inner client. + source := &HTTPMirrorSource{ + baseURL: baseURL, + creds: creds, + httpClient: retryHTTPClient, + locationConfig: LocationConfig{ProviderDownloadRetries: 2}, + } + + providerAddr := addrs.MustParseProviderSourceString("terraform.io/missing/providerbinary") + version := versions.MustParseVersion("1.2.0") + platform := Platform{OS: "tos", Arch: "m68k"} + got, err := source.PackageMeta(t.Context(), providerAddr, version, platform) + if err != nil { + t.Fatalf("unexpected error got from packageMeta: %s", err) + } + tmp := t.TempDir() + _, err = got.Location.InstallProviderPackage(t.Context(), got, tmp, nil) + if err == nil { + t.Fatalf("expected error but got nothing") + } + if expectedSuffix := "giving up after 3 attempt(s)"; !strings.HasSuffix(err.Error(), expectedSuffix) { + t.Fatalf("expected err %q to have suffix %q", err.Error(), expectedSuffix) + } +} diff --git a/internal/getproviders/location_config.go b/internal/getproviders/location_config.go new file mode 100644 index 0000000000..da6fa19954 --- /dev/null +++ b/internal/getproviders/location_config.go @@ -0,0 +1,19 @@ +// Copyright (c) The OpenTofu Authors +// SPDX-License-Identifier: MPL-2.0 +// Copyright (c) 2023 HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package getproviders + +import "time" + +// LocationConfig provides configuration options +// to be carried from the [Source] to the [PackageLocation] where applicable. +type LocationConfig struct { + // ProviderDownloadRetries is used for [PackageHTTPURL] to configure the + // http client to retry when a 5xx retryable error occurs. + ProviderDownloadRetries int + + // TODO - use this when we'll introduce per installation method configuration + ProviderDownloadTimeout time.Duration +} diff --git a/internal/getproviders/mock_source.go b/internal/getproviders/mock_source.go index a03d2dfea1..2d1803c437 100644 --- a/internal/getproviders/mock_source.go +++ b/internal/getproviders/mock_source.go @@ -135,7 +135,7 @@ func FakePackageMeta(provider addrs.Provider, version Version, protocols Version // Some fake but somewhat-realistic-looking other metadata. This // points nowhere, so will fail if attempting to actually use it. Filename: fmt.Sprintf("terraform-provider-%s_%s_%s.zip", provider.Type, version.String(), target.String()), - Location: PackageHTTPURL(fmt.Sprintf("https://fake.invalid/terraform-provider-%s_%s.zip", provider.Type, version.String())), + Location: PackageHTTPURL{URL: fmt.Sprintf("https://fake.invalid/terraform-provider-%s_%s.zip", provider.Type, version.String())}, } } diff --git a/internal/getproviders/package_location_http_archive.go b/internal/getproviders/package_location_http_archive.go index 01c084308e..e870881b74 100644 --- a/internal/getproviders/package_location_http_archive.go +++ b/internal/getproviders/package_location_http_archive.go @@ -11,15 +11,14 @@ import ( "log" "net/http" "os" - "strconv" "github.com/hashicorp/go-getter" "github.com/hashicorp/go-retryablehttp" + "github.com/opentofu/opentofu/internal/httpclient" + "github.com/opentofu/opentofu/internal/logging" semconv "go.opentelemetry.io/otel/semconv/v1.30.0" "go.opentelemetry.io/otel/trace" - "github.com/opentofu/opentofu/internal/httpclient" - "github.com/opentofu/opentofu/internal/logging" "github.com/opentofu/opentofu/internal/tracing" ) @@ -28,11 +27,28 @@ import ( // Its value is a URL string using either the http: scheme or the https: scheme. // The URL should respond with a .zip archive whose contents are to be extracted // into a local package directory. -type PackageHTTPURL string +// +// This type evolved from a single specific HTTP URL location to a more evolved +// type since we had to extract the logic of making an HTTP request into a more +// configurable piece to be able to provider a more specific-per-configuration +// http client. +type PackageHTTPURL struct { + // URL indicates the fully qualified location where the package + // resides and should be used as it is to perform a request. + URL string + // ClientBuilder is the external given function that is meant to + // construct a [*retryablehttp.Client] for making the actual request + // to the [PackageHTTPURL.URL] specified above. + // This came from the need of unifying the retries and timeout configurations + // into a single place. This way, the "creator" of this struct + // can inject a client by its liking to customize the requests + // accordingly. + ClientBuilder func(ctx context.Context) *retryablehttp.Client +} -var _ PackageLocation = PackageHTTPURL("") +var _ PackageLocation = PackageHTTPURL{} -func (p PackageHTTPURL) String() string { return string(p) } +func (p PackageHTTPURL) String() string { return p.URL } func (p PackageHTTPURL) InstallProviderPackage(ctx context.Context, meta PackageMeta, targetDir string, allowedHashes []Hash) (*PackageAuthenticationResult, error) { url := meta.Location.String() @@ -50,15 +66,7 @@ func (p PackageHTTPURL) InstallProviderPackage(ctx context.Context, meta Package // through X-Terraform-Get header, attempting partial fetches for // files that already exist, etc.) - retryableClient := retryablehttp.NewClient() - retryableClient.HTTPClient = httpclient.New(ctx) - retryableClient.RetryMax = maxHTTPPackageRetryCount - retryableClient.RequestLogHook = func(logger retryablehttp.Logger, _ *http.Request, i int) { - if i > 0 { - logger.Printf("[INFO] failed to fetch provider package; retrying") - } - } - retryableClient.Logger = log.New(logging.LogOutput(), "", log.Flags()) + retryableClient := p.ClientBuilder(ctx) req, err := retryablehttp.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { @@ -126,31 +134,20 @@ func (p PackageHTTPURL) InstallProviderPackage(ctx context.Context, meta Package return authResult, nil } -const ( - // httpClientRetryCountEnvName is the environment variable name used to customize - // the HTTP retry count for module downloads. - httpClientRetryCountEnvName = "TF_PROVIDER_DOWNLOAD_RETRY" - - httpClientDefaultRetry = 2 -) - -//nolint:gochecknoinits // this init function predates our use of this linter -func init() { - configureProviderDownloadRetry() -} - -var ( - //nolint:gochecknoglobals // this variable predates our use of this linter - maxHTTPPackageRetryCount int -) - -// will attempt for requests with retryable errors, like 502 status codes -func configureProviderDownloadRetry() { - maxHTTPPackageRetryCount = httpClientDefaultRetry - if v := os.Getenv(httpClientRetryCountEnvName); v != "" { - retry, err := strconv.Atoi(v) - if err == nil && retry > 0 { - maxHTTPPackageRetryCount = retry +// packageHTTPUrlClientWithRetry is the extracted logic from the [PackageHTTPURL.InstallProviderPackage] to be +// able to reuse the same logic with a custom retry. +// This is kept as it was previously, before being moved here, to avoid introducing unwanted behaviors +// in a package download process. +// Later, this method might be removed in favor of a more common client from [httpclient] package. +func packageHTTPUrlClientWithRetry(ctx context.Context, retries int) *retryablehttp.Client { + retryableClient := retryablehttp.NewClient() + retryableClient.HTTPClient = httpclient.New(ctx) + retryableClient.RetryMax = retries + retryableClient.RequestLogHook = func(logger retryablehttp.Logger, _ *http.Request, i int) { + if i > 0 { + logger.Printf("[INFO] failed to fetch provider package; retrying") } } + retryableClient.Logger = log.New(logging.LogOutput(), "", log.Flags()) + return retryableClient } diff --git a/internal/getproviders/registry_client.go b/internal/getproviders/registry_client.go index 626059b005..09cfd8fec8 100644 --- a/internal/getproviders/registry_client.go +++ b/internal/getproviders/registry_client.go @@ -44,13 +44,16 @@ type registryClient struct { creds svcauth.HostCredentials httpClient *retryablehttp.Client + + locationConfig LocationConfig } -func newRegistryClient(ctx context.Context, baseURL *url.URL, creds svcauth.HostCredentials, httpClient *retryablehttp.Client) *registryClient { +func newRegistryClient(ctx context.Context, baseURL *url.URL, creds svcauth.HostCredentials, httpClient *retryablehttp.Client, locationConfig LocationConfig) *registryClient { return ®istryClient{ - baseURL: baseURL, - creds: creds, - httpClient: httpClient, + baseURL: baseURL, + creds: creds, + httpClient: httpClient, + locationConfig: locationConfig, } } @@ -293,7 +296,9 @@ func (c *registryClient) PackageMeta(ctx context.Context, provider addrs.Provide Arch: body.Arch, }, Filename: body.Filename, - Location: PackageHTTPURL(downloadURL.String()), + Location: PackageHTTPURL{URL: downloadURL.String(), ClientBuilder: func(ctx context.Context) *retryablehttp.Client { + return packageHTTPUrlClientWithRetry(ctx, c.locationConfig.ProviderDownloadRetries) + }}, // "Authentication" is populated below } diff --git a/internal/getproviders/registry_client_test.go b/internal/getproviders/registry_client_test.go index 8b095b72cd..df12130582 100644 --- a/internal/getproviders/registry_client_test.go +++ b/internal/getproviders/registry_client_test.go @@ -16,10 +16,9 @@ import ( "github.com/apparentlymart/go-versions/versions" "github.com/google/go-cmp/cmp" + "github.com/opentofu/opentofu/internal/addrs" "github.com/opentofu/svchost" disco "github.com/opentofu/svchost/disco" - - "github.com/opentofu/opentofu/internal/addrs" ) // testRegistryServices starts up a local HTTP server running a fake provider registry @@ -75,7 +74,13 @@ func testRegistryServices(t *testing.T) (services *disco.Disco, baseURL string, // of your test in order to shut down the test server. func testRegistrySource(t *testing.T) (source *RegistrySource, baseURL string, cleanup func()) { services, baseURL, close := testRegistryServices(t) - source = NewRegistrySource(t.Context(), services, nil) + source = NewRegistrySource(t.Context(), services, nil, LocationConfig{ProviderDownloadRetries: 0}) + return source, baseURL, close +} + +func testRegistrySourceWithLocationConfig(t *testing.T, config LocationConfig) (source *RegistrySource, baseURL string, cleanup func()) { + services, baseURL, close := testRegistryServices(t) + source = NewRegistrySource(t.Context(), services, nil, config) return source, baseURL, close } @@ -117,6 +122,13 @@ func fakeRegistryHandler(resp http.ResponseWriter, req *http.Request) { write([]byte("000000000000000000000000000000000000000000000000000000000000f00d happycloud_1.2.0.zip\n000000000000000000000000000000000000000000000000000000000000face happycloud_1.2.0_face.zip\n")) case "/pkg/awesomesauce/happycloud_1.2.0_SHA256SUMS.sig": write([]byte("GPG signature")) + case "/pkg/missing/providerbinary_1.2.0.zip": + // Just return a retryable status code + resp.WriteHeader(http.StatusInternalServerError) + case "/pkg/missing/providerbinary_1.2.0_SHA256SUMS": + write([]byte("000000000000000000000000000000000000000000000000000000000000f00d providerbinary_1.2.0.zip\n000000000000000000000000000000000000000000000000000000000000face providerbinary_1.2.0_face.zip\n")) + case "/pkg/missing/providerbinary_1.2.0_SHA256SUMS.sig": + write([]byte("GPG signature")) default: resp.WriteHeader(404) write([]byte("unknown package file download")) @@ -194,7 +206,9 @@ func fakeRegistryHandler(resp http.ResponseWriter, req *http.Request) { if len(pathParts) == 6 && pathParts[3] == "download" { switch pathParts[0] + "/" + pathParts[1] { - case "awesomesauce/happycloud": + case "awesomesauce/happycloud", "missing/providerbinary": + pNamespace := pathParts[0] + pType := pathParts[1] if pathParts[4] == "nonexist" { resp.WriteHeader(404) write([]byte(`unsupported OS`)) @@ -215,11 +229,11 @@ func fakeRegistryHandler(resp http.ResponseWriter, req *http.Request) { "protocols": protocols, "os": pathParts[4], "arch": pathParts[5], - "filename": "happycloud_" + version + ".zip", + "filename": fmt.Sprintf("%s_%s.zip", pType, version), "shasum": "000000000000000000000000000000000000000000000000000000000000f00d", - "download_url": "/pkg/awesomesauce/happycloud_" + version + ".zip", - "shasums_url": "/pkg/awesomesauce/happycloud_" + version + "_SHA256SUMS", - "shasums_signature_url": "/pkg/awesomesauce/happycloud_" + version + "_SHA256SUMS.sig", + "download_url": fmt.Sprintf("/pkg/%s/%s_%s.zip", pNamespace, pType, version), + "shasums_url": fmt.Sprintf("/pkg/%s/%s_%s_SHA256SUMS", pNamespace, pType, version), + "shasums_signature_url": fmt.Sprintf("/pkg/%s/%s_%s_SHA256SUMS.sig", pNamespace, pType, version), "signing_keys": map[string]interface{}{ "gpg_public_keys": []map[string]interface{}{ { @@ -393,3 +407,32 @@ func TestFindClosestProtocolCompatibleVersion(t *testing.T) { }) } } + +// Checks that the [LocationConfig] is used properly to configure the [PackageHTTPURL] http client, +// meaning that the retries are configured as expected. +func TestLocationRetriesConfiguredCorrectly(t *testing.T) { + source, _, close := testRegistrySourceWithLocationConfig(t, LocationConfig{ProviderDownloadRetries: 2}) + defer close() + + parts := strings.Split("example.com/missing/providerbinary", "/") + providerAddr := addrs.Provider{ + Hostname: svchost.Hostname(parts[0]), + Namespace: parts[1], + Type: parts[2], + } + + version := versions.MustParseVersion("1.2.0") + + got, err := source.PackageMeta(t.Context(), providerAddr, version, Platform{"linux", "amd64"}) + if err != nil { + t.Fatalf("unexpected error got from packageMeta: %s", err) + } + tmp := t.TempDir() + _, err = got.Location.InstallProviderPackage(t.Context(), got, tmp, nil) + if err == nil { + t.Fatalf("expected error but got nothing") + } + if expectedSuffix := "giving up after 3 attempt(s)"; !strings.HasSuffix(err.Error(), expectedSuffix) { + t.Fatalf("expected err %q to have suffix %q", err.Error(), expectedSuffix) + } +} diff --git a/internal/getproviders/registry_source.go b/internal/getproviders/registry_source.go index 630ba37268..d9fbf5f4a9 100644 --- a/internal/getproviders/registry_source.go +++ b/internal/getproviders/registry_source.go @@ -23,13 +23,15 @@ import ( type RegistrySource struct { services *disco.Disco httpClient *retryablehttp.Client + + locationConfig LocationConfig } var _ Source = (*RegistrySource)(nil) // NewRegistrySource creates and returns a new source that will install // providers from their originating provider registries. -func NewRegistrySource(ctx context.Context, services *disco.Disco, httpClient *retryablehttp.Client) *RegistrySource { +func NewRegistrySource(ctx context.Context, services *disco.Disco, httpClient *retryablehttp.Client, locationCfg LocationConfig) *RegistrySource { if httpClient == nil { // As an aid to our tests that don't really care that much about // the HTTP client configuration, we'll provide some reasonable @@ -38,8 +40,9 @@ func NewRegistrySource(ctx context.Context, services *disco.Disco, httpClient *r } return &RegistrySource{ - services: services, - httpClient: httpClient, + services: services, + httpClient: httpClient, + locationConfig: locationCfg, } } @@ -159,7 +162,7 @@ func (s *RegistrySource) registryClient(ctx context.Context, hostname svchost.Ho return nil, fmt.Errorf("failed to retrieve credentials for %s: %w", hostname, err) } - return newRegistryClient(ctx, url, creds, s.httpClient), nil + return newRegistryClient(ctx, url, creds, s.httpClient, s.locationConfig), nil } func (s *RegistrySource) ForDisplay(provider addrs.Provider) string { diff --git a/internal/getproviders/registry_source_test.go b/internal/getproviders/registry_source_test.go index 6c334a9e17..7c1196a5ed 100644 --- a/internal/getproviders/registry_source_test.go +++ b/internal/getproviders/registry_source_test.go @@ -14,6 +14,7 @@ import ( "github.com/apparentlymart/go-versions/versions" "github.com/google/go-cmp/cmp" + "github.com/hashicorp/go-retryablehttp" regaddr "github.com/opentofu/registry-address/v2" "github.com/opentofu/svchost" @@ -114,7 +115,7 @@ func TestSourceAvailableVersions_warnings(t *testing.T) { } func TestSourcePackageMeta(t *testing.T) { - source, baseURL, close := testRegistrySource(t) + source, baseURL, close := testRegistrySourceWithLocationConfig(t, LocationConfig{ProviderDownloadRetries: 3}) defer close() validMeta := PackageMeta{ @@ -125,7 +126,9 @@ func TestSourcePackageMeta(t *testing.T) { ProtocolVersions: VersionList{versions.MustParseVersion("5.0.0")}, TargetPlatform: Platform{"linux", "amd64"}, Filename: "happycloud_1.2.0.zip", - Location: PackageHTTPURL(baseURL + "/pkg/awesomesauce/happycloud_1.2.0.zip"), + Location: PackageHTTPURL{URL: baseURL + "/pkg/awesomesauce/happycloud_1.2.0.zip", ClientBuilder: func(ctx context.Context) *retryablehttp.Client { + return packageHTTPUrlClientWithRetry(ctx, source.locationConfig.ProviderDownloadRetries) + }}, } validMeta.Authentication = PackageAuthenticationAll( NewMatchingChecksumAuthentication( @@ -196,7 +199,14 @@ func TestSourcePackageMeta(t *testing.T) { // consistently match those. Instead, we'll normalize the URLs. urlPattern := regexp.MustCompile(`http://[^/]+/`) - cmpOpts := cmp.Comparer(Version.Same) + cmpOpts := []cmp.Option{cmp.Comparer(Version.Same), cmp.Comparer(func(a, b func(ctx context.Context) *retryablehttp.Client) bool { + given := a(t.Context()) + got := b(t.Context()) + if got.RetryMax != given.RetryMax { + t.Logf("expected to have retry max as %d but got %d", got.RetryMax, given.RetryMax) + } + return true + })} for _, test := range tests { t.Run(fmt.Sprintf("%s for %s_%s", test.provider, test.os, test.arch), func(t *testing.T) { @@ -228,7 +238,7 @@ func TestSourcePackageMeta(t *testing.T) { t.Fatalf("wrong error\ngot: \nwant: %s", test.wantErr) } - if diff := cmp.Diff(got, test.want, cmpOpts); diff != "" { + if diff := cmp.Diff(got, test.want, cmpOpts...); diff != "" { t.Errorf("wrong result\n%s", diff) } }) diff --git a/internal/providercache/installer_test.go b/internal/providercache/installer_test.go index c0d5c44b57..f82548aedb 100644 --- a/internal/providercache/installer_test.go +++ b/internal/providercache/installer_test.go @@ -2745,7 +2745,7 @@ func testServices(t *testing.T) (services *disco.Disco, baseURL string, cleanup // of your test in order to shut down the test server. func testRegistrySource(t *testing.T) (source *getproviders.RegistrySource, baseURL string, cleanup func()) { services, baseURL, close := testServices(t) - source = getproviders.NewRegistrySource(t.Context(), services, nil) + source = getproviders.NewRegistrySource(t.Context(), services, nil, getproviders.LocationConfig{ProviderDownloadRetries: 0}) return source, baseURL, close }