From a39d2fe55b760707718dcd8225b7dc42dcef4c9c Mon Sep 17 00:00:00 2001 From: Brad Warren Date: Tue, 27 Feb 2018 18:05:33 -0800 Subject: [PATCH] Fix wildcard issuance (#5620) * Add is_wildcard_domain to certbot.util. * Error with --allow-subset-of-names and wildcards. * Fix issue preventing wildcard cert issuance. * Kill assumption domain is unique in auth_handler * fix typo and add test * update comments --- certbot/auth_handler.py | 158 ++++++++++++++++------------- certbot/cli.py | 5 + certbot/client.py | 7 +- certbot/tests/auth_handler_test.py | 80 +++++++-------- certbot/tests/cli_test.py | 4 + certbot/tests/client_test.py | 1 + certbot/tests/util_test.py | 20 ++++ certbot/util.py | 18 ++++ 8 files changed, 180 insertions(+), 113 deletions(-) diff --git a/certbot/auth_handler.py b/certbot/auth_handler.py index 9cc10d4b4..2b38e4af5 100644 --- a/certbot/auth_handler.py +++ b/certbot/auth_handler.py @@ -1,4 +1,5 @@ """ACME AuthHandler.""" +import collections import logging import time @@ -17,6 +18,10 @@ from certbot import interfaces logger = logging.getLogger(__name__) +AnnotatedAuthzr = collections.namedtuple("AnnotatedAuthzr", ["authzr", "achalls"]) +"""Stores an authorization resource and its active annotated challenges.""" + + class AuthHandler(object): """ACME Authorization Handler for a client. @@ -29,10 +34,8 @@ class AuthHandler(object): :ivar account: Client's Account :type account: :class:`certbot.account.Account` - :ivar dict authzr: ACME Authorization Resource dict where keys are domains - and values are :class:`acme.messages.AuthorizationResource` - :ivar list achalls: DV challenges in the form of - :class:`certbot.achallenges.AnnotatedChallenge` + :ivar aauthzrs: ACME Authorization Resources and their active challenges + :type aauthzrs: `list` of `AnnotatedAuthzr` :ivar list pref_challs: sorted user specified preferred challenges type strings with the most preferred challenge listed first @@ -42,12 +45,9 @@ class AuthHandler(object): self.acme = acme self.account = account - self.authzr = dict() + self.aauthzrs = [] self.pref_challs = pref_challs - # List must be used to keep responses straight. - self.achalls = [] - def handle_authorizations(self, orderr, best_effort=False): """Retrieve all authorizations for challenges. @@ -63,17 +63,15 @@ class AuthHandler(object): authorizations """ - authzrs = orderr.authorizations - for authzr in authzrs: - self.authzr[authzr.body.identifier.value] = authzr - domains = self.authzr.keys() + for authzr in orderr.authorizations: + self.aauthzrs.append(AnnotatedAuthzr(authzr, [])) - self._choose_challenges(domains) + self._choose_challenges() config = zope.component.getUtility(interfaces.IConfig) notify = zope.component.getUtility(interfaces.IDisplay).notification # While there are still challenges remaining... - while self.achalls: + while self._has_challenges(): resp = self._solve_challenges() logger.info("Waiting for verification...") if config.debug_challenges: @@ -87,8 +85,8 @@ class AuthHandler(object): self.verify_authzr_complete() # Only return valid authorizations - retVal = [authzr for authzr in self.authzr.values() - if authzr.body.status == messages.STATUS_VALID] + retVal = [aauthzr.authzr for aauthzr in self.aauthzrs + if aauthzr.authzr.body.status == messages.STATUS_VALID] if not retVal: raise errors.AuthorizationError( @@ -96,41 +94,54 @@ class AuthHandler(object): return retVal - def _choose_challenges(self, domains): + def _choose_challenges(self): """Retrieve necessary challenges to satisfy server.""" logger.info("Performing the following challenges:") - for dom in domains: - dom_challenges = self.authzr[dom].body.challenges + for aauthzr in self.aauthzrs: + aauthzr_challenges = aauthzr.authzr.body.challenges if self.acme.acme_version == 1: - combinations = self.authzr[dom].body.combinations + combinations = aauthzr.authzr.body.combinations else: - combinations = tuple((i,) for i in range(len(dom_challenges))) + combinations = tuple((i,) for i in range(len(aauthzr_challenges))) path = gen_challenge_path( - dom_challenges, - self._get_chall_pref(dom), + aauthzr_challenges, + self._get_chall_pref(aauthzr.authzr.body.identifier.value), combinations) - dom_achalls = self._challenge_factory( - dom, path) - self.achalls.extend(dom_achalls) + aauthzr_achalls = self._challenge_factory( + aauthzr.authzr, path) + aauthzr.achalls.extend(aauthzr_achalls) + + def _has_challenges(self): + """Do we have any challenges to perform?""" + return any(aauthzr.achalls for aauthzr in self.aauthzrs) def _solve_challenges(self): """Get Responses for challenges from authenticators.""" resp = [] + all_achalls = self._get_all_achalls() with error_handler.ErrorHandler(self._cleanup_challenges): try: - if self.achalls: - resp = self.auth.perform(self.achalls) + if all_achalls: + resp = self.auth.perform(all_achalls) except errors.AuthorizationError: logger.critical("Failure in setting up challenges.") logger.info("Attempting to clean up outstanding challenges...") raise - assert len(resp) == len(self.achalls) + assert len(resp) == len(all_achalls) return resp + def _get_all_achalls(self): + """Return all active challenges.""" + all_achalls = [] + for aauthzr in self.aauthzrs: + all_achalls.extend(aauthzr.achalls) + + return all_achalls + def _respond(self, resp, best_effort): """Send/Receive confirmation of all challenges. @@ -139,69 +150,67 @@ class AuthHandler(object): """ # TODO: chall_update is a dirty hack to get around acme-spec #105 chall_update = dict() - active_achalls = self._send_responses(self.achalls, - resp, chall_update) + active_achalls = self._send_responses(resp, chall_update) # Check for updated status... try: self._poll_challenges(chall_update, best_effort) finally: - # This removes challenges from self.achalls self._cleanup_challenges(active_achalls) - def _send_responses(self, achalls, resps, chall_update): + def _send_responses(self, resps, chall_update): """Send responses and make sure errors are handled. :param dict chall_update: parameter that is updated to hold - authzr -> list of outstanding solved annotated challenges + aauthzr index to list of outstanding solved annotated challenges """ active_achalls = [] - for achall, resp in six.moves.zip(achalls, resps): - # This line needs to be outside of the if block below to - # ensure failed challenges are cleaned up correctly - active_achalls.append(achall) + resps_iter = iter(resps) + for i, aauthzr in enumerate(self.aauthzrs): + for achall in aauthzr.achalls: + # This line needs to be outside of the if block below to + # ensure failed challenges are cleaned up correctly + active_achalls.append(achall) - # Don't send challenges for None and False authenticator responses - if resp is not None and resp: - self.acme.answer_challenge(achall.challb, resp) - # TODO: answer_challenge returns challr, with URI, - # that can be used in _find_updated_challr - # comparisons... - if achall.domain in chall_update: - chall_update[achall.domain].append(achall) - else: - chall_update[achall.domain] = [achall] + resp = next(resps_iter) + # Don't send challenges for None and False authenticator responses + if resp: + self.acme.answer_challenge(achall.challb, resp) + # TODO: answer_challenge returns challr, with URI, + # that can be used in _find_updated_challr + # comparisons... + chall_update.setdefault(i, []).append(achall) return active_achalls def _poll_challenges( self, chall_update, best_effort, min_sleep=3, max_rounds=15): """Wait for all challenge results to be determined.""" - dom_to_check = set(chall_update.keys()) - comp_domains = set() + indices_to_check = set(chall_update.keys()) + comp_indices = set() rounds = 0 - while dom_to_check and rounds < max_rounds: + while indices_to_check and rounds < max_rounds: # TODO: Use retry-after... time.sleep(min_sleep) all_failed_achalls = set() - for domain in dom_to_check: + for index in indices_to_check: comp_achalls, failed_achalls = self._handle_check( - domain, chall_update[domain]) + index, chall_update[index]) - if len(comp_achalls) == len(chall_update[domain]): - comp_domains.add(domain) + if len(comp_achalls) == len(chall_update[index]): + comp_indices.add(index) elif not failed_achalls: for achall, _ in comp_achalls: - chall_update[domain].remove(achall) + chall_update[index].remove(achall) # We failed some challenges... damage control else: if best_effort: - comp_domains.add(domain) + comp_indices.add(index) logger.warning( "Challenge failed for domain %s", - domain) + self.aauthzrs[index].authzr.body.identifier.value) else: all_failed_achalls.update( updated for _, updated in failed_achalls) @@ -210,24 +219,26 @@ class AuthHandler(object): _report_failed_challs(all_failed_achalls) raise errors.FailedChallenges(all_failed_achalls) - dom_to_check -= comp_domains - comp_domains.clear() + indices_to_check -= comp_indices + comp_indices.clear() rounds += 1 - def _handle_check(self, domain, achalls): + def _handle_check(self, index, achalls): """Returns tuple of ('completed', 'failed').""" completed = [] failed = [] - self.authzr[domain], _ = self.acme.poll(self.authzr[domain]) - if self.authzr[domain].body.status == messages.STATUS_VALID: + original_aauthzr = self.aauthzrs[index] + updated_authzr, _ = self.acme.poll(original_aauthzr.authzr) + self.aauthzrs[index] = AnnotatedAuthzr(updated_authzr, original_aauthzr.achalls) + if updated_authzr.body.status == messages.STATUS_VALID: return achalls, [] # Note: if the whole authorization is invalid, the individual failed # challenges will be determined here... for achall in achalls: updated_achall = achall.update(challb=self._find_updated_challb( - self.authzr[domain], achall)) + updated_authzr, achall)) # This does nothing for challenges that have yet to be decided yet. if updated_achall.status == messages.STATUS_VALID: @@ -285,14 +296,17 @@ class AuthHandler(object): logger.info("Cleaning up challenges") if achall_list is None: - achalls = self.achalls + achalls = self._get_all_achalls() else: achalls = achall_list if achalls: self.auth.cleanup(achalls) for achall in achalls: - self.achalls.remove(achall) + for aauthzr in self.aauthzrs: + if achall in aauthzr.achalls: + aauthzr.achalls.remove(achall) + break def verify_authzr_complete(self): """Verifies that all authorizations have been decided. @@ -301,15 +315,16 @@ class AuthHandler(object): :rtype: bool """ - for authzr in self.authzr.values(): + for aauthzr in self.aauthzrs: + authzr = aauthzr.authzr if (authzr.body.status != messages.STATUS_VALID and authzr.body.status != messages.STATUS_INVALID): raise errors.AuthorizationError("Incomplete authorizations") - def _challenge_factory(self, domain, path): + def _challenge_factory(self, authzr, path): """Construct Namedtuple Challenges - :param str domain: domain of the enrollee + :param messages.AuthorizationResource authzr: authorization :param list path: List of indices from `challenges`. @@ -323,8 +338,9 @@ class AuthHandler(object): achalls = [] for index in path: - challb = self.authzr[domain].body.challenges[index] - achalls.append(challb_to_achall(challb, self.account.key, domain)) + challb = authzr.body.challenges[index] + achalls.append(challb_to_achall( + challb, self.account.key, authzr.body.identifier.value)) return achalls diff --git a/certbot/cli.py b/certbot/cli.py index 09dd71d13..1c2273c8a 100644 --- a/certbot/cli.py +++ b/certbot/cli.py @@ -599,6 +599,11 @@ class HelpfulArgumentParser(object): if parsed_args.validate_hooks: hooks.validate_hooks(parsed_args) + if parsed_args.allow_subset_of_names: + if any(util.is_wildcard_domain(d) for d in parsed_args.domains): + raise errors.Error("Using --allow-subset-of-names with a" + " wildcard domain is not supported.") + possible_deprecation_warning(parsed_args) return parsed_args diff --git a/certbot/client.py b/certbot/client.py index 0f4fa760d..81fc0b802 100644 --- a/certbot/client.py +++ b/certbot/client.py @@ -298,7 +298,12 @@ class Client(object): auth_domains = set(a.body.identifier.value for a in authzr) successful_domains = [d for d in domains if d in auth_domains] - if successful_domains != domains: + # allow_subset_of_names is currently disabled for wildcard + # certificates. The reason for this and checking allow_subset_of_names + # below is because successful_domains == domains is never true if + # domains contains a wildcard because the ACME spec forbids identifiers + # in authzs from containing a wildcard character. + if self.config.allow_subset_of_names and successful_domains != domains: if not self.config.dry_run: os.remove(key.file) os.remove(csr.file) diff --git a/certbot/tests/auth_handler_test.py b/certbot/tests/auth_handler_test.py index 394002206..7650d2c95 100644 --- a/certbot/tests/auth_handler_test.py +++ b/certbot/tests/auth_handler_test.py @@ -29,32 +29,31 @@ class ChallengeFactoryTest(unittest.TestCase): # Account is mocked... self.handler = AuthHandler(None, None, mock.Mock(key="mock_key"), []) - self.dom = "test" - self.handler.authzr[self.dom] = acme_util.gen_authzr( - messages.STATUS_PENDING, self.dom, acme_util.CHALLENGES, + self.authzr = acme_util.gen_authzr( + messages.STATUS_PENDING, "test", acme_util.CHALLENGES, [messages.STATUS_PENDING] * 6, False) def test_all(self): achalls = self.handler._challenge_factory( - self.dom, range(0, len(acme_util.CHALLENGES))) + self.authzr, range(0, len(acme_util.CHALLENGES))) self.assertEqual( [achall.chall for achall in achalls], acme_util.CHALLENGES) def test_one_tls_sni(self): - achalls = self.handler._challenge_factory(self.dom, [1]) + achalls = self.handler._challenge_factory(self.authzr, [1]) self.assertEqual( [achall.chall for achall in achalls], [acme_util.TLSSNI01]) def test_unrecognized(self): - self.handler.authzr["failure.com"] = acme_util.gen_authzr( - messages.STATUS_PENDING, "failure.com", - [mock.Mock(chall="chall", typ="unrecognized")], - [messages.STATUS_PENDING]) + authzr = acme_util.gen_authzr( + messages.STATUS_PENDING, "test", + [mock.Mock(chall="chall", typ="unrecognized")], + [messages.STATUS_PENDING]) self.assertRaises( - errors.Error, self.handler._challenge_factory, "failure.com", [0]) + errors.Error, self.handler._challenge_factory, authzr, [0]) class HandleAuthorizationsTest(unittest.TestCase): @@ -103,7 +102,7 @@ class HandleAuthorizationsTest(unittest.TestCase): self.assertEqual(mock_poll.call_count, 1) chall_update = mock_poll.call_args[0][0] - self.assertEqual(list(six.iterkeys(chall_update)), ["0"]) + self.assertEqual(list(six.iterkeys(chall_update)), [0]) self.assertEqual(len(chall_update.values()), 1) self.assertEqual(self.mock_auth.cleanup.call_count, 1) @@ -134,7 +133,7 @@ class HandleAuthorizationsTest(unittest.TestCase): self.assertEqual(mock_poll.call_count, 1) chall_update = mock_poll.call_args[0][0] - self.assertEqual(list(six.iterkeys(chall_update)), ["0"]) + self.assertEqual(list(six.iterkeys(chall_update)), [0]) self.assertEqual(len(chall_update.values()), 1) self.assertEqual(self.mock_auth.cleanup.call_count, 1) @@ -160,7 +159,7 @@ class HandleAuthorizationsTest(unittest.TestCase): self.assertEqual(mock_poll.call_count, 1) chall_update = mock_poll.call_args[0][0] - self.assertEqual(list(six.iterkeys(chall_update)), ["0"]) + self.assertEqual(list(six.iterkeys(chall_update)), [0]) self.assertEqual(len(chall_update.values()), 1) self.assertEqual(self.mock_auth.cleanup.call_count, 1) @@ -190,12 +189,12 @@ class HandleAuthorizationsTest(unittest.TestCase): self.assertEqual(mock_poll.call_count, 1) chall_update = mock_poll.call_args[0][0] self.assertEqual(len(list(six.iterkeys(chall_update))), 3) - self.assertTrue("0" in list(six.iterkeys(chall_update))) - self.assertEqual(len(chall_update["0"]), 1) - self.assertTrue("1" in list(six.iterkeys(chall_update))) - self.assertEqual(len(chall_update["1"]), 1) - self.assertTrue("2" in list(six.iterkeys(chall_update))) - self.assertEqual(len(chall_update["2"]), 1) + self.assertTrue(0 in list(six.iterkeys(chall_update))) + self.assertEqual(len(chall_update[0]), 1) + self.assertTrue(1 in list(six.iterkeys(chall_update))) + self.assertEqual(len(chall_update[1]), 1) + self.assertTrue(2 in list(six.iterkeys(chall_update))) + self.assertEqual(len(chall_update[2]), 1) self.assertEqual(self.mock_auth.cleanup.call_count, 1) @@ -274,14 +273,15 @@ class HandleAuthorizationsTest(unittest.TestCase): self._test_preferred_challenges_not_supported_common(combos=False) def _validate_all(self, unused_1, unused_2): - for dom in six.iterkeys(self.handler.authzr): - azr = self.handler.authzr[dom] - self.handler.authzr[dom] = acme_util.gen_authzr( + for i, aauthzr in enumerate(self.handler.aauthzrs): + azr = aauthzr.authzr + updated_azr = acme_util.gen_authzr( messages.STATUS_VALID, - dom, + azr.body.identifier.value, [challb.chall for challb in azr.body.challenges], [messages.STATUS_VALID] * len(azr.body.challenges), azr.body.combinations) + self.handler.aauthzrs[i] = type(aauthzr)(updated_azr, aauthzr.achalls) class PollChallengesTest(unittest.TestCase): @@ -290,7 +290,7 @@ class PollChallengesTest(unittest.TestCase): def setUp(self): from certbot.auth_handler import challb_to_achall - from certbot.auth_handler import AuthHandler + from certbot.auth_handler import AuthHandler, AnnotatedAuthzr # Account and network are mocked... self.mock_net = mock.MagicMock() @@ -298,40 +298,38 @@ class PollChallengesTest(unittest.TestCase): None, self.mock_net, mock.Mock(key="mock_key"), []) self.doms = ["0", "1", "2"] - self.handler.authzr[self.doms[0]] = acme_util.gen_authzr( + self.handler.aauthzrs.append(AnnotatedAuthzr(acme_util.gen_authzr( messages.STATUS_PENDING, self.doms[0], [acme_util.HTTP01, acme_util.TLSSNI01], - [messages.STATUS_PENDING] * 2, False) - - self.handler.authzr[self.doms[1]] = acme_util.gen_authzr( + [messages.STATUS_PENDING] * 2, False), [])) + self.handler.aauthzrs.append(AnnotatedAuthzr(acme_util.gen_authzr( messages.STATUS_PENDING, self.doms[1], - acme_util.CHALLENGES, [messages.STATUS_PENDING] * 3, False) - - self.handler.authzr[self.doms[2]] = acme_util.gen_authzr( + acme_util.CHALLENGES, [messages.STATUS_PENDING] * 3, False), [])) + self.handler.aauthzrs.append(AnnotatedAuthzr(acme_util.gen_authzr( messages.STATUS_PENDING, self.doms[2], - acme_util.CHALLENGES, [messages.STATUS_PENDING] * 3, False) + acme_util.CHALLENGES, [messages.STATUS_PENDING] * 3, False), [])) self.chall_update = {} - for dom in self.doms: - self.chall_update[dom] = [ - challb_to_achall(challb, mock.Mock(key="dummy_key"), dom) - for challb in self.handler.authzr[dom].body.challenges] + for i, aauthzr in enumerate(self.handler.aauthzrs): + self.chall_update[i] = [ + challb_to_achall(challb, mock.Mock(key="dummy_key"), self.doms[i]) + for challb in aauthzr.authzr.body.challenges] @mock.patch("certbot.auth_handler.time") def test_poll_challenges(self, unused_mock_time): self.mock_net.poll.side_effect = self._mock_poll_solve_one_valid self.handler._poll_challenges(self.chall_update, False) - for authzr in self.handler.authzr.values(): - self.assertEqual(authzr.body.status, messages.STATUS_VALID) + for aauthzr in self.handler.aauthzrs: + self.assertEqual(aauthzr.authzr.body.status, messages.STATUS_VALID) @mock.patch("certbot.auth_handler.time") def test_poll_challenges_failure_best_effort(self, unused_mock_time): self.mock_net.poll.side_effect = self._mock_poll_solve_one_invalid self.handler._poll_challenges(self.chall_update, True) - for authzr in self.handler.authzr.values(): - self.assertEqual(authzr.body.status, messages.STATUS_PENDING) + for aauthzr in self.handler.aauthzrs: + self.assertEqual(aauthzr.authzr.body.status, messages.STATUS_PENDING) @mock.patch("certbot.auth_handler.time") @test_util.patch_get_utility() @@ -345,7 +343,7 @@ class PollChallengesTest(unittest.TestCase): def test_unable_to_find_challenge_status(self, unused_mock_time): from certbot.auth_handler import challb_to_achall self.mock_net.poll.side_effect = self._mock_poll_solve_one_valid - self.chall_update[self.doms[0]].append( + self.chall_update[0].append( challb_to_achall(acme_util.DNS01_P, "key", self.doms[0])) self.assertRaises( errors.AuthorizationError, self.handler._poll_challenges, diff --git a/certbot/tests/cli_test.py b/certbot/tests/cli_test.py index c5935d722..1bba6991a 100644 --- a/certbot/tests/cli_test.py +++ b/certbot/tests/cli_test.py @@ -426,6 +426,10 @@ class ParseTest(unittest.TestCase): # pylint: disable=too-many-public-methods namespace = self.parse(["--no-delete-after-revoke"]) self.assertFalse(namespace.delete_after_revoke) + def test_allow_subset_with_wildcard(self): + self.assertRaises(errors.Error, self.parse, + "--allow-subset-of-names -d *.example.org".split()) + class DefaultTest(unittest.TestCase): """Tests for certbot.cli._Default.""" diff --git a/certbot/tests/client_test.py b/certbot/tests/client_test.py index ed9c140e7..b51275d9e 100644 --- a/certbot/tests/client_test.py +++ b/certbot/tests/client_test.py @@ -222,6 +222,7 @@ class ClientTest(ClientTestCommon): mock.sentinel.chain) authzr = self._authzr_from_domains(["example.com"]) + self.config.allow_subset_of_names = True self._test_obtain_certificate_common(key, csr, authzr_ret=authzr, auth_count=2) self.assertEqual(mock_crypto_util.init_save_key.call_count, 2) diff --git a/certbot/tests/util_test.py b/certbot/tests/util_test.py index 50d323ffd..0e280f3ab 100644 --- a/certbot/tests/util_test.py +++ b/certbot/tests/util_test.py @@ -487,6 +487,26 @@ class EnforceDomainSanityTest(unittest.TestCase): self._call('this.is.xn--ls8h.tld') +class IsWildcardDomainTest(unittest.TestCase): + """Tests for is_wildcard_domain.""" + + def setUp(self): + self.wildcard = u"*.example.org" + self.no_wildcard = u"example.org" + + def _call(self, domain): + from certbot.util import is_wildcard_domain + return is_wildcard_domain(domain) + + def test_no_wildcard(self): + self.assertFalse(self._call(self.no_wildcard)) + self.assertFalse(self._call(self.no_wildcard.encode())) + + def test_wildcard(self): + self.assertTrue(self._call(self.wildcard)) + self.assertTrue(self._call(self.wildcard.encode())) + + class OsInfoTest(unittest.TestCase): """Test OS / distribution detection""" diff --git a/certbot/util.py b/certbot/util.py index f47c5da9c..f7ce6a3bc 100644 --- a/certbot/util.py +++ b/certbot/util.py @@ -601,6 +601,24 @@ def enforce_domain_sanity(domain): return domain +def is_wildcard_domain(domain): + """"Is domain a wildcard domain? + + :param damain: domain to check + :type domain: `bytes` or `str` or `unicode` + + :returns: True if domain is a wildcard, otherwise, False + :rtype: bool + + """ + if isinstance(domain, six.text_type): + wildcard_marker = u"*." + else: + wildcard_marker = b"*." + + return domain.startswith(wildcard_marker) + + def get_strict_version(normalized): """Converts a normalized version to a strict version.