Merge branch 'master' into v2-integration

This commit is contained in:
Brad Warren 2018-02-28 08:02:02 -08:00
commit 63fc31eec4
8 changed files with 180 additions and 113 deletions

View file

@ -1,4 +1,5 @@
"""ACME AuthHandler."""
import collections
import logging
import time
@ -17,6 +18,10 @@ from certbot import interfaces
logger = logging.getLogger(__name__)
AnnotatedAuthzr = collections.namedtuple("AnnotatedAuthzr", ["authzr", "achalls"])
"""Stores an authorization resource and its active annotated challenges."""
class AuthHandler(object):
"""ACME Authorization Handler for a client.
@ -29,10 +34,8 @@ class AuthHandler(object):
:ivar account: Client's Account
:type account: :class:`certbot.account.Account`
:ivar dict authzr: ACME Authorization Resource dict where keys are domains
and values are :class:`acme.messages.AuthorizationResource`
:ivar list achalls: DV challenges in the form of
:class:`certbot.achallenges.AnnotatedChallenge`
:ivar aauthzrs: ACME Authorization Resources and their active challenges
:type aauthzrs: `list` of `AnnotatedAuthzr`
:ivar list pref_challs: sorted user specified preferred challenges
type strings with the most preferred challenge listed first
@ -42,12 +45,9 @@ class AuthHandler(object):
self.acme = acme
self.account = account
self.authzr = dict()
self.aauthzrs = []
self.pref_challs = pref_challs
# List must be used to keep responses straight.
self.achalls = []
def handle_authorizations(self, orderr, best_effort=False):
"""Retrieve all authorizations for challenges.
@ -63,17 +63,15 @@ class AuthHandler(object):
authorizations
"""
authzrs = orderr.authorizations
for authzr in authzrs:
self.authzr[authzr.body.identifier.value] = authzr
domains = self.authzr.keys()
for authzr in orderr.authorizations:
self.aauthzrs.append(AnnotatedAuthzr(authzr, []))
self._choose_challenges(domains)
self._choose_challenges()
config = zope.component.getUtility(interfaces.IConfig)
notify = zope.component.getUtility(interfaces.IDisplay).notification
# While there are still challenges remaining...
while self.achalls:
while self._has_challenges():
resp = self._solve_challenges()
logger.info("Waiting for verification...")
if config.debug_challenges:
@ -87,8 +85,8 @@ class AuthHandler(object):
self.verify_authzr_complete()
# Only return valid authorizations
retVal = [authzr for authzr in self.authzr.values()
if authzr.body.status == messages.STATUS_VALID]
retVal = [aauthzr.authzr for aauthzr in self.aauthzrs
if aauthzr.authzr.body.status == messages.STATUS_VALID]
if not retVal:
raise errors.AuthorizationError(
@ -96,41 +94,54 @@ class AuthHandler(object):
return retVal
def _choose_challenges(self, domains):
def _choose_challenges(self):
"""Retrieve necessary challenges to satisfy server."""
logger.info("Performing the following challenges:")
for dom in domains:
dom_challenges = self.authzr[dom].body.challenges
for aauthzr in self.aauthzrs:
aauthzr_challenges = aauthzr.authzr.body.challenges
if self.acme.acme_version == 1:
combinations = self.authzr[dom].body.combinations
combinations = aauthzr.authzr.body.combinations
else:
combinations = tuple((i,) for i in range(len(dom_challenges)))
combinations = tuple((i,) for i in range(len(aauthzr_challenges)))
path = gen_challenge_path(
dom_challenges,
self._get_chall_pref(dom),
aauthzr_challenges,
self._get_chall_pref(aauthzr.authzr.body.identifier.value),
combinations)
dom_achalls = self._challenge_factory(
dom, path)
self.achalls.extend(dom_achalls)
aauthzr_achalls = self._challenge_factory(
aauthzr.authzr, path)
aauthzr.achalls.extend(aauthzr_achalls)
def _has_challenges(self):
"""Do we have any challenges to perform?"""
return any(aauthzr.achalls for aauthzr in self.aauthzrs)
def _solve_challenges(self):
"""Get Responses for challenges from authenticators."""
resp = []
all_achalls = self._get_all_achalls()
with error_handler.ErrorHandler(self._cleanup_challenges):
try:
if self.achalls:
resp = self.auth.perform(self.achalls)
if all_achalls:
resp = self.auth.perform(all_achalls)
except errors.AuthorizationError:
logger.critical("Failure in setting up challenges.")
logger.info("Attempting to clean up outstanding challenges...")
raise
assert len(resp) == len(self.achalls)
assert len(resp) == len(all_achalls)
return resp
def _get_all_achalls(self):
"""Return all active challenges."""
all_achalls = []
for aauthzr in self.aauthzrs:
all_achalls.extend(aauthzr.achalls)
return all_achalls
def _respond(self, resp, best_effort):
"""Send/Receive confirmation of all challenges.
@ -139,69 +150,67 @@ class AuthHandler(object):
"""
# TODO: chall_update is a dirty hack to get around acme-spec #105
chall_update = dict()
active_achalls = self._send_responses(self.achalls,
resp, chall_update)
active_achalls = self._send_responses(resp, chall_update)
# Check for updated status...
try:
self._poll_challenges(chall_update, best_effort)
finally:
# This removes challenges from self.achalls
self._cleanup_challenges(active_achalls)
def _send_responses(self, achalls, resps, chall_update):
def _send_responses(self, resps, chall_update):
"""Send responses and make sure errors are handled.
:param dict chall_update: parameter that is updated to hold
authzr -> list of outstanding solved annotated challenges
aauthzr index to list of outstanding solved annotated challenges
"""
active_achalls = []
for achall, resp in six.moves.zip(achalls, resps):
# This line needs to be outside of the if block below to
# ensure failed challenges are cleaned up correctly
active_achalls.append(achall)
resps_iter = iter(resps)
for i, aauthzr in enumerate(self.aauthzrs):
for achall in aauthzr.achalls:
# This line needs to be outside of the if block below to
# ensure failed challenges are cleaned up correctly
active_achalls.append(achall)
# Don't send challenges for None and False authenticator responses
if resp is not None and resp:
self.acme.answer_challenge(achall.challb, resp)
# TODO: answer_challenge returns challr, with URI,
# that can be used in _find_updated_challr
# comparisons...
if achall.domain in chall_update:
chall_update[achall.domain].append(achall)
else:
chall_update[achall.domain] = [achall]
resp = next(resps_iter)
# Don't send challenges for None and False authenticator responses
if resp:
self.acme.answer_challenge(achall.challb, resp)
# TODO: answer_challenge returns challr, with URI,
# that can be used in _find_updated_challr
# comparisons...
chall_update.setdefault(i, []).append(achall)
return active_achalls
def _poll_challenges(
self, chall_update, best_effort, min_sleep=3, max_rounds=15):
"""Wait for all challenge results to be determined."""
dom_to_check = set(chall_update.keys())
comp_domains = set()
indices_to_check = set(chall_update.keys())
comp_indices = set()
rounds = 0
while dom_to_check and rounds < max_rounds:
while indices_to_check and rounds < max_rounds:
# TODO: Use retry-after...
time.sleep(min_sleep)
all_failed_achalls = set()
for domain in dom_to_check:
for index in indices_to_check:
comp_achalls, failed_achalls = self._handle_check(
domain, chall_update[domain])
index, chall_update[index])
if len(comp_achalls) == len(chall_update[domain]):
comp_domains.add(domain)
if len(comp_achalls) == len(chall_update[index]):
comp_indices.add(index)
elif not failed_achalls:
for achall, _ in comp_achalls:
chall_update[domain].remove(achall)
chall_update[index].remove(achall)
# We failed some challenges... damage control
else:
if best_effort:
comp_domains.add(domain)
comp_indices.add(index)
logger.warning(
"Challenge failed for domain %s",
domain)
self.aauthzrs[index].authzr.body.identifier.value)
else:
all_failed_achalls.update(
updated for _, updated in failed_achalls)
@ -210,24 +219,26 @@ class AuthHandler(object):
_report_failed_challs(all_failed_achalls)
raise errors.FailedChallenges(all_failed_achalls)
dom_to_check -= comp_domains
comp_domains.clear()
indices_to_check -= comp_indices
comp_indices.clear()
rounds += 1
def _handle_check(self, domain, achalls):
def _handle_check(self, index, achalls):
"""Returns tuple of ('completed', 'failed')."""
completed = []
failed = []
self.authzr[domain], _ = self.acme.poll(self.authzr[domain])
if self.authzr[domain].body.status == messages.STATUS_VALID:
original_aauthzr = self.aauthzrs[index]
updated_authzr, _ = self.acme.poll(original_aauthzr.authzr)
self.aauthzrs[index] = AnnotatedAuthzr(updated_authzr, original_aauthzr.achalls)
if updated_authzr.body.status == messages.STATUS_VALID:
return achalls, []
# Note: if the whole authorization is invalid, the individual failed
# challenges will be determined here...
for achall in achalls:
updated_achall = achall.update(challb=self._find_updated_challb(
self.authzr[domain], achall))
updated_authzr, achall))
# This does nothing for challenges that have yet to be decided yet.
if updated_achall.status == messages.STATUS_VALID:
@ -285,14 +296,17 @@ class AuthHandler(object):
logger.info("Cleaning up challenges")
if achall_list is None:
achalls = self.achalls
achalls = self._get_all_achalls()
else:
achalls = achall_list
if achalls:
self.auth.cleanup(achalls)
for achall in achalls:
self.achalls.remove(achall)
for aauthzr in self.aauthzrs:
if achall in aauthzr.achalls:
aauthzr.achalls.remove(achall)
break
def verify_authzr_complete(self):
"""Verifies that all authorizations have been decided.
@ -301,15 +315,16 @@ class AuthHandler(object):
:rtype: bool
"""
for authzr in self.authzr.values():
for aauthzr in self.aauthzrs:
authzr = aauthzr.authzr
if (authzr.body.status != messages.STATUS_VALID and
authzr.body.status != messages.STATUS_INVALID):
raise errors.AuthorizationError("Incomplete authorizations")
def _challenge_factory(self, domain, path):
def _challenge_factory(self, authzr, path):
"""Construct Namedtuple Challenges
:param str domain: domain of the enrollee
:param messages.AuthorizationResource authzr: authorization
:param list path: List of indices from `challenges`.
@ -323,8 +338,9 @@ class AuthHandler(object):
achalls = []
for index in path:
challb = self.authzr[domain].body.challenges[index]
achalls.append(challb_to_achall(challb, self.account.key, domain))
challb = authzr.body.challenges[index]
achalls.append(challb_to_achall(
challb, self.account.key, authzr.body.identifier.value))
return achalls

View file

@ -599,6 +599,11 @@ class HelpfulArgumentParser(object):
if parsed_args.validate_hooks:
hooks.validate_hooks(parsed_args)
if parsed_args.allow_subset_of_names:
if any(util.is_wildcard_domain(d) for d in parsed_args.domains):
raise errors.Error("Using --allow-subset-of-names with a"
" wildcard domain is not supported.")
possible_deprecation_warning(parsed_args)
return parsed_args

View file

@ -298,7 +298,12 @@ class Client(object):
auth_domains = set(a.body.identifier.value for a in authzr)
successful_domains = [d for d in domains if d in auth_domains]
if successful_domains != domains:
# allow_subset_of_names is currently disabled for wildcard
# certificates. The reason for this and checking allow_subset_of_names
# below is because successful_domains == domains is never true if
# domains contains a wildcard because the ACME spec forbids identifiers
# in authzs from containing a wildcard character.
if self.config.allow_subset_of_names and successful_domains != domains:
if not self.config.dry_run:
os.remove(key.file)
os.remove(csr.file)

View file

@ -29,32 +29,31 @@ class ChallengeFactoryTest(unittest.TestCase):
# Account is mocked...
self.handler = AuthHandler(None, None, mock.Mock(key="mock_key"), [])
self.dom = "test"
self.handler.authzr[self.dom] = acme_util.gen_authzr(
messages.STATUS_PENDING, self.dom, acme_util.CHALLENGES,
self.authzr = acme_util.gen_authzr(
messages.STATUS_PENDING, "test", acme_util.CHALLENGES,
[messages.STATUS_PENDING] * 6, False)
def test_all(self):
achalls = self.handler._challenge_factory(
self.dom, range(0, len(acme_util.CHALLENGES)))
self.authzr, range(0, len(acme_util.CHALLENGES)))
self.assertEqual(
[achall.chall for achall in achalls], acme_util.CHALLENGES)
def test_one_tls_sni(self):
achalls = self.handler._challenge_factory(self.dom, [1])
achalls = self.handler._challenge_factory(self.authzr, [1])
self.assertEqual(
[achall.chall for achall in achalls], [acme_util.TLSSNI01])
def test_unrecognized(self):
self.handler.authzr["failure.com"] = acme_util.gen_authzr(
messages.STATUS_PENDING, "failure.com",
[mock.Mock(chall="chall", typ="unrecognized")],
[messages.STATUS_PENDING])
authzr = acme_util.gen_authzr(
messages.STATUS_PENDING, "test",
[mock.Mock(chall="chall", typ="unrecognized")],
[messages.STATUS_PENDING])
self.assertRaises(
errors.Error, self.handler._challenge_factory, "failure.com", [0])
errors.Error, self.handler._challenge_factory, authzr, [0])
class HandleAuthorizationsTest(unittest.TestCase):
@ -103,7 +102,7 @@ class HandleAuthorizationsTest(unittest.TestCase):
self.assertEqual(mock_poll.call_count, 1)
chall_update = mock_poll.call_args[0][0]
self.assertEqual(list(six.iterkeys(chall_update)), ["0"])
self.assertEqual(list(six.iterkeys(chall_update)), [0])
self.assertEqual(len(chall_update.values()), 1)
self.assertEqual(self.mock_auth.cleanup.call_count, 1)
@ -134,7 +133,7 @@ class HandleAuthorizationsTest(unittest.TestCase):
self.assertEqual(mock_poll.call_count, 1)
chall_update = mock_poll.call_args[0][0]
self.assertEqual(list(six.iterkeys(chall_update)), ["0"])
self.assertEqual(list(six.iterkeys(chall_update)), [0])
self.assertEqual(len(chall_update.values()), 1)
self.assertEqual(self.mock_auth.cleanup.call_count, 1)
@ -160,7 +159,7 @@ class HandleAuthorizationsTest(unittest.TestCase):
self.assertEqual(mock_poll.call_count, 1)
chall_update = mock_poll.call_args[0][0]
self.assertEqual(list(six.iterkeys(chall_update)), ["0"])
self.assertEqual(list(six.iterkeys(chall_update)), [0])
self.assertEqual(len(chall_update.values()), 1)
self.assertEqual(self.mock_auth.cleanup.call_count, 1)
@ -190,12 +189,12 @@ class HandleAuthorizationsTest(unittest.TestCase):
self.assertEqual(mock_poll.call_count, 1)
chall_update = mock_poll.call_args[0][0]
self.assertEqual(len(list(six.iterkeys(chall_update))), 3)
self.assertTrue("0" in list(six.iterkeys(chall_update)))
self.assertEqual(len(chall_update["0"]), 1)
self.assertTrue("1" in list(six.iterkeys(chall_update)))
self.assertEqual(len(chall_update["1"]), 1)
self.assertTrue("2" in list(six.iterkeys(chall_update)))
self.assertEqual(len(chall_update["2"]), 1)
self.assertTrue(0 in list(six.iterkeys(chall_update)))
self.assertEqual(len(chall_update[0]), 1)
self.assertTrue(1 in list(six.iterkeys(chall_update)))
self.assertEqual(len(chall_update[1]), 1)
self.assertTrue(2 in list(six.iterkeys(chall_update)))
self.assertEqual(len(chall_update[2]), 1)
self.assertEqual(self.mock_auth.cleanup.call_count, 1)
@ -274,14 +273,15 @@ class HandleAuthorizationsTest(unittest.TestCase):
self._test_preferred_challenges_not_supported_common(combos=False)
def _validate_all(self, unused_1, unused_2):
for dom in six.iterkeys(self.handler.authzr):
azr = self.handler.authzr[dom]
self.handler.authzr[dom] = acme_util.gen_authzr(
for i, aauthzr in enumerate(self.handler.aauthzrs):
azr = aauthzr.authzr
updated_azr = acme_util.gen_authzr(
messages.STATUS_VALID,
dom,
azr.body.identifier.value,
[challb.chall for challb in azr.body.challenges],
[messages.STATUS_VALID] * len(azr.body.challenges),
azr.body.combinations)
self.handler.aauthzrs[i] = type(aauthzr)(updated_azr, aauthzr.achalls)
class PollChallengesTest(unittest.TestCase):
@ -290,7 +290,7 @@ class PollChallengesTest(unittest.TestCase):
def setUp(self):
from certbot.auth_handler import challb_to_achall
from certbot.auth_handler import AuthHandler
from certbot.auth_handler import AuthHandler, AnnotatedAuthzr
# Account and network are mocked...
self.mock_net = mock.MagicMock()
@ -298,40 +298,38 @@ class PollChallengesTest(unittest.TestCase):
None, self.mock_net, mock.Mock(key="mock_key"), [])
self.doms = ["0", "1", "2"]
self.handler.authzr[self.doms[0]] = acme_util.gen_authzr(
self.handler.aauthzrs.append(AnnotatedAuthzr(acme_util.gen_authzr(
messages.STATUS_PENDING, self.doms[0],
[acme_util.HTTP01, acme_util.TLSSNI01],
[messages.STATUS_PENDING] * 2, False)
self.handler.authzr[self.doms[1]] = acme_util.gen_authzr(
[messages.STATUS_PENDING] * 2, False), []))
self.handler.aauthzrs.append(AnnotatedAuthzr(acme_util.gen_authzr(
messages.STATUS_PENDING, self.doms[1],
acme_util.CHALLENGES, [messages.STATUS_PENDING] * 3, False)
self.handler.authzr[self.doms[2]] = acme_util.gen_authzr(
acme_util.CHALLENGES, [messages.STATUS_PENDING] * 3, False), []))
self.handler.aauthzrs.append(AnnotatedAuthzr(acme_util.gen_authzr(
messages.STATUS_PENDING, self.doms[2],
acme_util.CHALLENGES, [messages.STATUS_PENDING] * 3, False)
acme_util.CHALLENGES, [messages.STATUS_PENDING] * 3, False), []))
self.chall_update = {}
for dom in self.doms:
self.chall_update[dom] = [
challb_to_achall(challb, mock.Mock(key="dummy_key"), dom)
for challb in self.handler.authzr[dom].body.challenges]
for i, aauthzr in enumerate(self.handler.aauthzrs):
self.chall_update[i] = [
challb_to_achall(challb, mock.Mock(key="dummy_key"), self.doms[i])
for challb in aauthzr.authzr.body.challenges]
@mock.patch("certbot.auth_handler.time")
def test_poll_challenges(self, unused_mock_time):
self.mock_net.poll.side_effect = self._mock_poll_solve_one_valid
self.handler._poll_challenges(self.chall_update, False)
for authzr in self.handler.authzr.values():
self.assertEqual(authzr.body.status, messages.STATUS_VALID)
for aauthzr in self.handler.aauthzrs:
self.assertEqual(aauthzr.authzr.body.status, messages.STATUS_VALID)
@mock.patch("certbot.auth_handler.time")
def test_poll_challenges_failure_best_effort(self, unused_mock_time):
self.mock_net.poll.side_effect = self._mock_poll_solve_one_invalid
self.handler._poll_challenges(self.chall_update, True)
for authzr in self.handler.authzr.values():
self.assertEqual(authzr.body.status, messages.STATUS_PENDING)
for aauthzr in self.handler.aauthzrs:
self.assertEqual(aauthzr.authzr.body.status, messages.STATUS_PENDING)
@mock.patch("certbot.auth_handler.time")
@test_util.patch_get_utility()
@ -345,7 +343,7 @@ class PollChallengesTest(unittest.TestCase):
def test_unable_to_find_challenge_status(self, unused_mock_time):
from certbot.auth_handler import challb_to_achall
self.mock_net.poll.side_effect = self._mock_poll_solve_one_valid
self.chall_update[self.doms[0]].append(
self.chall_update[0].append(
challb_to_achall(acme_util.DNS01_P, "key", self.doms[0]))
self.assertRaises(
errors.AuthorizationError, self.handler._poll_challenges,

View file

@ -426,6 +426,10 @@ class ParseTest(unittest.TestCase): # pylint: disable=too-many-public-methods
namespace = self.parse(["--no-delete-after-revoke"])
self.assertFalse(namespace.delete_after_revoke)
def test_allow_subset_with_wildcard(self):
self.assertRaises(errors.Error, self.parse,
"--allow-subset-of-names -d *.example.org".split())
class DefaultTest(unittest.TestCase):
"""Tests for certbot.cli._Default."""

View file

@ -222,6 +222,7 @@ class ClientTest(ClientTestCommon):
mock.sentinel.chain)
authzr = self._authzr_from_domains(["example.com"])
self.config.allow_subset_of_names = True
self._test_obtain_certificate_common(key, csr, authzr_ret=authzr, auth_count=2)
self.assertEqual(mock_crypto_util.init_save_key.call_count, 2)

View file

@ -487,6 +487,26 @@ class EnforceDomainSanityTest(unittest.TestCase):
self._call('this.is.xn--ls8h.tld')
class IsWildcardDomainTest(unittest.TestCase):
"""Tests for is_wildcard_domain."""
def setUp(self):
self.wildcard = u"*.example.org"
self.no_wildcard = u"example.org"
def _call(self, domain):
from certbot.util import is_wildcard_domain
return is_wildcard_domain(domain)
def test_no_wildcard(self):
self.assertFalse(self._call(self.no_wildcard))
self.assertFalse(self._call(self.no_wildcard.encode()))
def test_wildcard(self):
self.assertTrue(self._call(self.wildcard))
self.assertTrue(self._call(self.wildcard.encode()))
class OsInfoTest(unittest.TestCase):
"""Test OS / distribution detection"""

View file

@ -601,6 +601,24 @@ def enforce_domain_sanity(domain):
return domain
def is_wildcard_domain(domain):
""""Is domain a wildcard domain?
:param damain: domain to check
:type domain: `bytes` or `str` or `unicode`
:returns: True if domain is a wildcard, otherwise, False
:rtype: bool
"""
if isinstance(domain, six.text_type):
wildcard_marker = u"*."
else:
wildcard_marker = b"*."
return domain.startswith(wildcard_marker)
def get_strict_version(normalized):
"""Converts a normalized version to a strict version.