From 105abe7be58cca20ebeba7efc47bbe719105779d Mon Sep 17 00:00:00 2001 From: Joona Hoikkala Date: Tue, 3 Apr 2018 23:06:05 +0300 Subject: [PATCH] Interactive cert-name selection if not defined on CLI --- certbot/cert_manager.py | 14 ++++----- certbot/display/ops.py | 15 ++++++++++ certbot/main.py | 14 ++++++--- certbot/tests/main_test.py | 59 +++++++++++++++++++++++++++----------- 4 files changed, 75 insertions(+), 27 deletions(-) diff --git a/certbot/cert_manager.py b/certbot/cert_manager.py index 4240a0523..ca4b60684 100644 --- a/certbot/cert_manager.py +++ b/certbot/cert_manager.py @@ -46,7 +46,7 @@ def rename_lineage(config): """ disp = zope.component.getUtility(interfaces.IDisplay) - certname = _get_certnames(config, "rename")[0] + certname = get_certnames(config, "rename")[0] new_certname = config.new_certname if not new_certname: @@ -88,7 +88,7 @@ def certificates(config): def delete(config): """Delete Certbot files associated with a certificate lineage.""" - certnames = _get_certnames(config, "delete", allow_multiple=True) + certnames = get_certnames(config, "delete", allow_multiple=True) for certname in certnames: storage.delete_files(config, certname) disp = zope.component.getUtility(interfaces.IDisplay) @@ -288,11 +288,7 @@ def human_readable_cert_info(config, cert, skip_filter_checks=False): cert.privkey)) return "".join(certinfo) -################### -# Private Helpers -################### - -def _get_certnames(config, verb, allow_multiple=False): +def get_certnames(config, verb, allow_multiple=False): """Get certname from flag, interactively, or error out. """ certname = config.certname @@ -321,6 +317,10 @@ def _get_certnames(config, verb, allow_multiple=False): certnames = [choices[index]] return certnames +################### +# Private Helpers +################### + def _report_lines(msgs): """Format a results report for a category of single-line renewal outcomes""" return " " + "\n ".join(str(msg) for msg in msgs) diff --git a/certbot/display/ops.py b/certbot/display/ops.py index 2fdd07a65..f9d867b92 100644 --- a/certbot/display/ops.py +++ b/certbot/display/ops.py @@ -86,6 +86,21 @@ def choose_account(accounts): else: return None +def choose_values(values, question=None): + """Display screen to let user pick one or multiple values from the provided + list. + + :param list values: Values to select from + + :returns: List of selected values + :rtype: list + """ + code, names = z_util(interfaces.IDisplay).checklist( + question, tags=values, force_interactive=True) + if code == display_util.OK and names: + return names + else: + return [] def choose_names(installer, question=None): """Display screen to select domains to validate. diff --git a/certbot/main.py b/certbot/main.py index 797aadfec..d632f8c76 100644 --- a/certbot/main.py +++ b/certbot/main.py @@ -882,15 +882,21 @@ def enhance(config, plugins): logger.warning(msg, sys.argv[0]) raise errors.MisconfigurationError("No enhancements requested, exiting.") + if not config.certname: + config.certname = cert_manager.get_certnames(config, "enhance", + allow_multiple=False)[0] + cert_domains = cert_manager.domains_for_certname(config, config.certname) + domain_question = ("Which domain names would you like to enable the selected " + "enhancements for?") + domains = display_ops.choose_values(cert_domains, domain_question) + if not domains: + return try: installer, _ = plug_sel.choose_configurator_plugins(config, plugins, "enhance") except errors.PluginSelectionError as e: return str(e) le_client = _init_le_client(config, authenticator=None, installer=installer) - domain_question = ("Which domain names would you like to enable the selected " - "enhancements for") - domains, _ = _find_domains_or_certname(config, installer, domain_question) - le_client.enhance_config(domains, None, ask_redirect=False) + le_client.enhance_config(domains, config.chain_path, ask_redirect=False) def rollback(config, plugins): diff --git a/certbot/tests/main_test.py b/certbot/tests/main_test.py index 55beb5794..81f738688 100644 --- a/certbot/tests/main_test.py +++ b/certbot/tests/main_test.py @@ -1548,16 +1548,22 @@ class EnhanceTest(unittest.TestCase): config = configuration.NamespaceConfig( cli.prepare_and_parse_args(plugins, args)) - with mock.patch('certbot.main._init_le_client') as mock_init: - mock_client = mock.MagicMock() - mock_client.config = config - mock_init.return_value = mock_client - main.enhance(config, plugins) - return mock_client # returns the client + with mock.patch('certbot.cert_manager.get_certnames') as mock_certs: + mock_certs.return_value = ['example.com'] + with mock.patch('certbot.cert_manager.domains_for_certname') as mock_dom: + mock_dom.return_value = ['example.com'] + with mock.patch('certbot.main._init_le_client') as mock_init: + mock_client = mock.MagicMock() + mock_client.config = config + mock_init.return_value = mock_client + main.enhance(config, plugins) + return mock_client # returns the client @mock.patch("certbot.main.plug_sel.record_chosen_plugins") + @mock.patch("certbot.main.display_ops.choose_values") @mock.patch("certbot.main._find_domains_or_certname") - def test_selection_question(self, mock_find, _rec): + def test_selection_question(self, mock_find, mock_choose, _rec): + mock_choose.return_value = ['example.com'] mock_find.return_value = (None, None) with mock.patch('certbot.main.plug_sel.pick_installer') as mock_pick: self._call(['enhance', '--redirect']) @@ -1566,8 +1572,10 @@ class EnhanceTest(unittest.TestCase): self.assertTrue("enhancements" in mock_pick.call_args[0][3]) @mock.patch("certbot.main.plug_sel.record_chosen_plugins") + @mock.patch("certbot.main.display_ops.choose_values") @mock.patch("certbot.main._find_domains_or_certname") - def test_selection_auth_warning(self, mock_find, _rec): + def test_selection_auth_warning(self, mock_find, mock_choose, _rec): + mock_choose.return_value = ["example.com"] mock_find.return_value = (None, None) with mock.patch('certbot.main.plug_sel.pick_installer'): with mock.patch('certbot.main.plug_sel.logger.warning') as mock_log: @@ -1576,13 +1584,12 @@ class EnhanceTest(unittest.TestCase): self.assertTrue("make sense" in mock_log.call_args[0][0]) self.assertTrue(mock_client.enhance_config.called) + @mock.patch("certbot.main.display_ops.choose_values") @mock.patch("certbot.main.plug_sel.record_chosen_plugins") - def test_enhance_config_call(self, _rec): - # mock_find.return_value = (['domain.tld'], None) + def test_enhance_config_call(self, _rec, mock_choose): + mock_choose.return_value = ["example.com"] with mock.patch('certbot.main.plug_sel.pick_installer'): - mock_client = self._call(['enhance', '-d', 'domain.tld', - '-d', 'another.tld', '--redirect', - '--hsts']) + mock_client = self._call(['enhance', '--redirect', '--hsts']) req_enh = ["redirect", "hsts"] not_req_enh = ["uir"] self.assertTrue(mock_client.enhance_config.called) @@ -1591,9 +1598,29 @@ class EnhanceTest(unittest.TestCase): self.assertFalse( any([getattr(mock_client.config, e) for e in not_req_enh])) self.assertTrue( - "domain.tld" in mock_client.enhance_config.call_args[0][0]) - self.assertTrue( - "another.tld" in mock_client.enhance_config.call_args[0][0]) + "example.com" in mock_client.enhance_config.call_args[0][0]) + + @mock.patch("certbot.main.display_ops.choose_values") + @mock.patch("certbot.main.plug_sel.record_chosen_plugins") + def test_user_abort_domains(self, _rec, mock_choose): + mock_choose.return_value = [] + with mock.patch('certbot.main.plug_sel.pick_installer'): + mock_client = self._call(['enhance', '--redirect', '--hsts']) + self.assertFalse(mock_client.enhance_config.called) + + def test_no_enhancements_defined(self): + self.assertRaises(errors.MisconfigurationError, + self._call, ['enhance']) + + @mock.patch("certbot.main.plug_sel.choose_configurator_plugins") + @mock.patch("certbot.main.display_ops.choose_values") + @mock.patch("certbot.main.plug_sel.record_chosen_plugins") + def test_plugin_selection_error(self, _rec, mock_choose, mock_pick): + mock_choose.return_value = ["example.com"] + mock_pick.return_value = (None, None) + mock_pick.side_effect = errors.PluginSelectionError() + mock_client = self._call(['enhance', '--hsts']) + self.assertFalse(mock_client.enhance_config.called) if __name__ == '__main__': unittest.main() # pragma: no cover