diff --git a/certbot/cert_manager.py b/certbot/cert_manager.py index ca4b60684..d841c1912 100644 --- a/certbot/cert_manager.py +++ b/certbot/cert_manager.py @@ -288,7 +288,7 @@ def human_readable_cert_info(config, cert, skip_filter_checks=False): cert.privkey)) return "".join(certinfo) -def get_certnames(config, verb, allow_multiple=False): +def get_certnames(config, verb, allow_multiple=False, custom_prompt=None): """Get certname from flag, interactively, or error out. """ certname = config.certname @@ -301,16 +301,22 @@ def get_certnames(config, verb, allow_multiple=False): if not choices: raise errors.Error("No existing certificates found.") if allow_multiple: + if not custom_prompt: + prompt = "Which certificate(s) would you like to {0}?".format(verb) + else: + prompt = custom_prompt code, certnames = disp.checklist( - "Which certificate(s) would you like to {0}?".format(verb), - choices, cli_flag="--cert-name", - force_interactive=True) + prompt, choices, cli_flag="--cert-name", force_interactive=True) if code != display_util.OK: raise errors.Error("User ended interaction.") else: - code, index = disp.menu("Which certificate would you like to {0}?".format(verb), - choices, cli_flag="--cert-name", - force_interactive=True) + if not custom_prompt: + prompt = "Which certificate would you like to {0}?".format(verb) + else: + prompt = custom_prompt + + code, index = disp.menu( + prompt, choices, cli_flag="--cert-name", force_interactive=True) if code != display_util.OK or index not in range(0, len(choices)): raise errors.Error("User ended interaction.") diff --git a/certbot/display/ops.py b/certbot/display/ops.py index f9d867b92..1e15a8474 100644 --- a/certbot/display/ops.py +++ b/certbot/display/ops.py @@ -95,10 +95,10 @@ def choose_values(values, question=None): :returns: List of selected values :rtype: list """ - code, names = z_util(interfaces.IDisplay).checklist( + code, items = z_util(interfaces.IDisplay).checklist( question, tags=values, force_interactive=True) - if code == display_util.OK and names: - return names + if code == display_util.OK and items: + return items else: return [] diff --git a/certbot/main.py b/certbot/main.py index 400783ec0..e753ccdaf 100644 --- a/certbot/main.py +++ b/certbot/main.py @@ -887,16 +887,23 @@ def enhance(config, plugins): except errors.PluginSelectionError as e: return str(e) - if not config.certname: - config.certname = cert_manager.get_certnames(config, "enhance", - allow_multiple=False)[0] + certname_question = ("Which certificate would you like to use to enhance " + "your configuration?") + config.certname = cert_manager.get_certnames( + config, "enhance", allow_multiple=False, + custom_prompt=certname_question)[0] cert_domains = cert_manager.domains_for_certname(config, config.certname) domain_question = ("Which domain names would you like to enable the selected " "enhancements for?") - domains = display_ops.choose_values(cert_domains, domain_question) + if config.noninteractive_mode: + domains = cert_manager.domains_for_certname(config, config.certname) + else: + domains = display_ops.choose_values(cert_domains, domain_question) if not domains: - # To be consistent with the error messages (similar to get_certnames) - raise errors.Error("User ended interaction.") + raise errors.Error("No domains found to enhance.") + if not config.chain_path: + lineage = cert_manager.lineage_for_certname(config, config.certname) + config.chain_path = lineage.chain_path le_client = _init_le_client(config, authenticator=None, installer=installer) le_client.enhance_config(domains, config.chain_path, ask_redirect=False) diff --git a/certbot/tests/main_test.py b/certbot/tests/main_test.py index cc8a82629..5e7fb32e2 100644 --- a/certbot/tests/main_test.py +++ b/certbot/tests/main_test.py @@ -1559,10 +1559,12 @@ class EnhanceTest(unittest.TestCase): main.enhance(config, plugins) return mock_client # returns the client - @mock.patch("certbot.main.plug_sel.record_chosen_plugins") - @mock.patch("certbot.main.display_ops.choose_values") - @mock.patch("certbot.main._find_domains_or_certname") - def test_selection_question(self, mock_find, mock_choose, _rec): + @mock.patch('certbot.main.plug_sel.record_chosen_plugins') + @mock.patch('certbot.cert_manager.lineage_for_certname') + @mock.patch('certbot.main.display_ops.choose_values') + @mock.patch('certbot.main._find_domains_or_certname') + def test_selection_question(self, mock_find, mock_choose, mock_lineage, _rec): + mock_lineage.return_value = mock.MagicMock(chain_path="/tmp/nonexistent") mock_choose.return_value = ['example.com'] mock_find.return_value = (None, None) with mock.patch('certbot.main.plug_sel.pick_installer') as mock_pick: @@ -1571,10 +1573,12 @@ class EnhanceTest(unittest.TestCase): # Check that the message includes "enhancements" self.assertTrue("enhancements" in mock_pick.call_args[0][3]) - @mock.patch("certbot.main.plug_sel.record_chosen_plugins") - @mock.patch("certbot.main.display_ops.choose_values") - @mock.patch("certbot.main._find_domains_or_certname") - def test_selection_auth_warning(self, mock_find, mock_choose, _rec): + @mock.patch('certbot.main.plug_sel.record_chosen_plugins') + @mock.patch('certbot.cert_manager.lineage_for_certname') + @mock.patch('certbot.main.display_ops.choose_values') + @mock.patch('certbot.main._find_domains_or_certname') + def test_selection_auth_warning(self, mock_find, mock_choose, mock_lineage, _rec): + mock_lineage.return_value = mock.MagicMock(chain_path="/tmp/nonexistent") mock_choose.return_value = ["example.com"] mock_find.return_value = (None, None) with mock.patch('certbot.main.plug_sel.pick_installer'): @@ -1584,9 +1588,11 @@ class EnhanceTest(unittest.TestCase): self.assertTrue("make sense" in mock_log.call_args[0][0]) self.assertTrue(mock_client.enhance_config.called) - @mock.patch("certbot.main.display_ops.choose_values") - @mock.patch("certbot.main.plug_sel.record_chosen_plugins") - def test_enhance_config_call(self, _rec, mock_choose): + @mock.patch('certbot.cert_manager.lineage_for_certname') + @mock.patch('certbot.main.display_ops.choose_values') + @mock.patch('certbot.main.plug_sel.record_chosen_plugins') + def test_enhance_config_call(self, _rec, mock_choose, mock_lineage): + mock_lineage.return_value = mock.MagicMock(chain_path="/tmp/nonexistent") mock_choose.return_value = ["example.com"] with mock.patch('certbot.main.plug_sel.pick_installer'): mock_client = self._call(['enhance', '--redirect', '--hsts']) @@ -1600,8 +1606,22 @@ class EnhanceTest(unittest.TestCase): self.assertTrue( "example.com" in mock_client.enhance_config.call_args[0][0]) - @mock.patch("certbot.main.display_ops.choose_values") - @mock.patch("certbot.main.plug_sel.record_chosen_plugins") + @mock.patch('certbot.cert_manager.lineage_for_certname') + @mock.patch('certbot.main.display_ops.choose_values') + @mock.patch('certbot.main.plug_sel.record_chosen_plugins') + def test_enhance_noninteractive(self, _rec, mock_choose, mock_lineage): + mock_lineage.return_value = mock.MagicMock( + chain_path="/tmp/nonexistent") + mock_choose.return_value = ["example.com"] + with mock.patch('certbot.main.plug_sel.pick_installer'): + mock_client = self._call(['enhance', '--redirect', + '--hsts', '--non-interactive']) + self.assertTrue(mock_client.enhance_config.called) + self.assertFalse(mock_choose.called) + + + @mock.patch('certbot.main.display_ops.choose_values') + @mock.patch('certbot.main.plug_sel.record_chosen_plugins') def test_user_abort_domains(self, _rec, mock_choose): mock_choose.return_value = [] with mock.patch('certbot.main.plug_sel.pick_installer'): @@ -1613,9 +1633,9 @@ class EnhanceTest(unittest.TestCase): self.assertRaises(errors.MisconfigurationError, self._call, ['enhance']) - @mock.patch("certbot.main.plug_sel.choose_configurator_plugins") - @mock.patch("certbot.main.display_ops.choose_values") - @mock.patch("certbot.main.plug_sel.record_chosen_plugins") + @mock.patch('certbot.main.plug_sel.choose_configurator_plugins') + @mock.patch('certbot.main.display_ops.choose_values') + @mock.patch('certbot.main.plug_sel.record_chosen_plugins') def test_plugin_selection_error(self, _rec, mock_choose, mock_pick): mock_choose.return_value = ["example.com"] mock_pick.return_value = (None, None)