diff --git a/certbot/cli.py b/certbot/cli.py index cb769872e..e100c7715 100644 --- a/certbot/cli.py +++ b/certbot/cli.py @@ -515,6 +515,13 @@ class HelpfulArgumentParser(object): return usage + def remove_config_file_domains_for_renewal(self, parsed_args): + """Make "certbot renew" safe if domains are set in cli.ini.""" + # Works around https://github.com/certbot/certbot/issues/4096 + if self.verb == "renew": + for source, flags in self.parser._source_to_settings.items(): # pylint: disable=protected-access + if source.startswith("config_file") and "domains" in flags: + parsed_args.domains = _Default() if self.detect_defaults else [] def parse_args(self): """Parses command line arguments and returns the result. @@ -527,6 +534,8 @@ class HelpfulArgumentParser(object): parsed_args.func = self.VERBS[self.verb] parsed_args.verb = self.verb + self.remove_config_file_domains_for_renewal(parsed_args) + if self.detect_defaults: return parsed_args diff --git a/certbot/tests/cli_test.py b/certbot/tests/cli_test.py index 498bd309d..0a5e959c2 100644 --- a/certbot/tests/cli_test.py +++ b/certbot/tests/cli_test.py @@ -39,6 +39,7 @@ class TestReadFile(TempDirTestCase): self.assertEqual(contents, test_contents) + class ParseTest(unittest.TestCase): '''Test the cli args entrypoint''' @@ -61,6 +62,22 @@ class ParseTest(unittest.TestCase): self.assertRaises(SystemExit, self.parse, args, output) return output.getvalue() + @mock.patch("certbot.cli.flag_default") + def test_cli_ini_domains(self, mock_flag_default): + tmp_config = tempfile.NamedTemporaryFile() + # use a shim to get ConfigArgParse to pick up tmp_config + shim = lambda v: constants.CLI_DEFAULTS[v] if v != "config_files" else [tmp_config.name] + mock_flag_default.side_effect = shim + + namespace = self.parse(["certonly"]) + self.assertEqual(namespace.domains, []) + tmp_config.write(b"domains = example.com") + tmp_config.flush() + namespace = self.parse(["certonly"]) + self.assertEqual(namespace.domains, ["example.com"]) + namespace = self.parse(["renew"]) + self.assertEqual(namespace.domains, []) + def test_no_args(self): namespace = self.parse([]) for d in ('config_dir', 'logs_dir', 'work_dir'):