diff --git a/acme/acme/challenges.py b/acme/acme/challenges.py index 57e74144b..5f97547ee 100644 --- a/acme/acme/challenges.py +++ b/acme/acme/challenges.py @@ -187,7 +187,7 @@ class KeyAuthorizationChallenge(_TokenDVChallenge): key_authorization=self.key_authorization(account_key)) @abc.abstractmethod - def validation(self, account_key): + def validation(self, account_key, **kwargs): """Generate validation for the challenge. Subclasses must implement this method, but they are likely to @@ -201,7 +201,7 @@ class KeyAuthorizationChallenge(_TokenDVChallenge): """ raise NotImplementedError() # pragma: no cover - def response_and_validation(self, account_key): + def response_and_validation(self, account_key, *args, **kwargs): """Generate response and validation. Convenience function that return results of `response` and @@ -211,7 +211,8 @@ class KeyAuthorizationChallenge(_TokenDVChallenge): :rtype: tuple """ - return (self.response(account_key), self.validation(account_key)) + return (self.response(account_key), + self.validation(account_key, *args, **kwargs)) @ChallengeResponse.register @@ -308,7 +309,7 @@ class HTTP01(KeyAuthorizationChallenge): """ return "http://" + domain + self.path - def validation(self, account_key): + def validation(self, account_key, **unused_kwargs): """Generate validation. :param JWK account_key: @@ -318,6 +319,127 @@ class HTTP01(KeyAuthorizationChallenge): return self.key_authorization(account_key) +@ChallengeResponse.register +class TLSSNI01Response(KeyAuthorizationChallengeResponse): + """ACME tls-sni-01 challenge response.""" + typ = "tls-sni-01" + + DOMAIN_SUFFIX = b".acme.invalid" + """Domain name suffix.""" + + PORT = 443 + + @property + def z(self): + """``z`` value used for verification.""" + return hashlib.sha256( + self.key_authorization.encode("utf-8")).hexdigest().encode() + + @property + def z_domain(self): + """Domain name used for verification, generated from `z`.""" + return self.z[:32] + b'.' + self.z[32:] + self.DOMAIN_SUFFIX + + def gen_cert(self, key=None, bits=2048): + """Generate tls-sni-01 certificate. + + :param OpenSSL.crypto.PKey key: Optional private key used in + certificate generation. If not provided (``None``), then + fresh key will be generated. + :param int bits: Number of bits for newly generated key. + + :rtype: `tuple` of `OpenSSL.crypto.X509` and `OpenSSL.crypto.PKey` + + """ + if key is None: + key = OpenSSL.crypto.PKey() + key.generate_key(OpenSSL.crypto.TYPE_RSA, bits) + return crypto_util.gen_ss_cert(key, [ + # z_domain is too big to fit into CN, hence first dummy domain + 'dummy', self.z_domain.decode()], force_san=True), key + + def probe_cert(self, domain, **kwargs): + """Probe tls-sni-01 challenge certificate. + + :param unicode domain: + + """ + # TODO: domain is not necessary if host is provided + if "host" not in kwargs: + host = socket.gethostbyname(domain) + logging.debug('%s resolved to %s', domain, host) + kwargs["host"] = host + + kwargs.setdefault("port", self.PORT) + kwargs["name"] = self.z_domain + # TODO: try different methods? + # pylint: disable=protected-access + return crypto_util.probe_sni(**kwargs) + + def verify_cert(self, cert): + """Verify tls-sni-01 challenge certificate.""" + # pylint: disable=protected-access + sans = crypto_util._pyopenssl_cert_or_req_san(cert) + logging.debug('Certificate %s. SANs: %s', cert.digest('sha1'), sans) + return self.z_domain.decode() in sans + + def simple_verify(self, chall, domain, account_public_key, + cert=None, **kwargs): + """Simple verify. + + Verify ``validation`` using ``account_public_key``, optionally + probe tls-sni-01 certificate and check using `verify_cert`. + + :param .challenges.TLSSNI01 chall: Corresponding challenge. + :param str domain: Domain name being validated. + :param JWK account_public_key: + :param OpenSSL.crypto.X509 cert: Optional certificate. If not + provided (``None``) certificate will be retrieved using + `probe_cert`. + + + :returns: ``True`` iff client's control of the domain has been + verified, ``False`` otherwise. + :rtype: bool + + """ + if not self.verify(chall, account_public_key): + logger.debug("Verification of key authorization in response failed") + return False + + if cert is None: + try: + cert = self.probe_cert(domain=domain, **kwargs) + except errors.Error as error: + logger.debug(error, exc_info=True) + return False + + return self.verify_cert(cert) + + +@Challenge.register # pylint: disable=too-many-ancestors +class TLSSNI01(KeyAuthorizationChallenge): + """ACME tls-sni-01 challenge.""" + response_cls = TLSSNI01Response + typ = response_cls.typ + + # boulder#962, ietf-wg-acme#22 + #n = jose.Field("n", encoder=int, decoder=int) + + def validation(self, account_key, **kwargs): + """Generate validation. + + :param JWK account_key: + :param OpenSSL.crypto.PKey cert_key: Optional private key used + in certificate generation. If not provided (``None``), then + fresh key will be generated. + + :rtype: `tuple` of `OpenSSL.crypto.X509` and `OpenSSL.crypto.PKey` + + """ + return self.response(account_key).gen_cert(key=kwargs.get('cert_key')) + + @Challenge.register # pylint: disable=too-many-ancestors class DVSNI(_TokenDVChallenge): """ACME "dvsni" challenge. diff --git a/acme/acme/challenges_test.py b/acme/acme/challenges_test.py index 86291d0e8..3fcb01e4d 100644 --- a/acme/acme/challenges_test.py +++ b/acme/acme/challenges_test.py @@ -186,6 +186,140 @@ class HTTP01Test(unittest.TestCase): self.msg.update(token=b'..').good_token) +class TLSSNI01ResponseTest(unittest.TestCase): + # pylint: disable=too-many-instance-attributes + + def setUp(self): + from acme.challenges import TLSSNI01 + self.chall = TLSSNI01( + token=jose.b64decode(b'a82d5ff8ef740d12881f6d3c2277ab2e')) + + self.response = self.chall.response(KEY) + self.jmsg = { + 'resource': 'challenge', + 'type': 'tls-sni-01', + 'keyAuthorization': self.response.key_authorization, + } + + # pylint: disable=invalid-name + label1 = b'dc38d9c3fa1a4fdcc3a5501f2d38583f' + label2 = b'b7793728f084394f2a1afd459556bb5c' + self.z = label1 + label2 + self.z_domain = label1 + b'.' + label2 + b'.acme.invalid' + self.domain = 'foo.com' + + def test_z_and_domain(self): + self.assertEqual(self.z, self.response.z) + self.assertEqual(self.z_domain, self.response.z_domain) + + def test_to_partial_json(self): + self.assertEqual(self.jmsg, self.response.to_partial_json()) + + def test_from_json(self): + from acme.challenges import TLSSNI01Response + self.assertEqual(self.response, TLSSNI01Response.from_json(self.jmsg)) + + def test_from_json_hashable(self): + from acme.challenges import TLSSNI01Response + hash(TLSSNI01Response.from_json(self.jmsg)) + + @mock.patch('acme.challenges.socket.gethostbyname') + @mock.patch('acme.challenges.crypto_util.probe_sni') + def test_probe_cert(self, mock_probe_sni, mock_gethostbyname): + mock_gethostbyname.return_value = '127.0.0.1' + self.response.probe_cert('foo.com') + mock_gethostbyname.assert_called_once_with('foo.com') + mock_probe_sni.assert_called_once_with( + host='127.0.0.1', port=self.response.PORT, + name=self.z_domain) + + self.response.probe_cert('foo.com', host='8.8.8.8') + mock_probe_sni.assert_called_with( + host='8.8.8.8', port=mock.ANY, name=mock.ANY) + + self.response.probe_cert('foo.com', port=1234) + mock_probe_sni.assert_called_with( + host=mock.ANY, port=1234, name=mock.ANY) + + self.response.probe_cert('foo.com', bar='baz') + mock_probe_sni.assert_called_with( + host=mock.ANY, port=mock.ANY, name=mock.ANY, bar='baz') + + self.response.probe_cert('foo.com', name=b'xxx') + mock_probe_sni.assert_called_with( + host=mock.ANY, port=mock.ANY, + name=self.z_domain) + + def test_gen_verify_cert(self): + key1 = test_util.load_pyopenssl_private_key('rsa512_key.pem') + cert, key2 = self.response.gen_cert(key1) + self.assertEqual(key1, key2) + self.assertTrue(self.response.verify_cert(cert)) + + def test_gen_verify_cert_gen_key(self): + cert, key = self.response.gen_cert() + self.assertTrue(isinstance(key, OpenSSL.crypto.PKey)) + self.assertTrue(self.response.verify_cert(cert)) + + def test_verify_bad_cert(self): + self.assertFalse(self.response.verify_cert( + test_util.load_cert('cert.pem'))) + + def test_simple_verify_bad_key_authorization(self): + key2 = jose.JWKRSA.load(test_util.load_vector('rsa256_key.pem')) + self.response.simple_verify(self.chall, "local", key2.public_key()) + + @mock.patch('acme.challenges.TLSSNI01Response.verify_cert', autospec=True) + def test_simple_verify(self, mock_verify_cert): + mock_verify_cert.return_value = mock.sentinel.verification + self.assertEqual(mock.sentinel.verification, self.response.simple_verify( + self.chall, self.domain, KEY.public_key(), + cert=mock.sentinel.cert)) + mock_verify_cert.assert_called_once_with(self.response, mock.sentinel.cert) + + @mock.patch('acme.challenges.TLSSNI01Response.probe_cert') + def test_simple_verify_false_on_probe_error(self, mock_probe_cert): + mock_probe_cert.side_effect = errors.Error + self.assertFalse(self.response.simple_verify( + self.chall, self.domain, KEY.public_key())) + + +class TLSSNI01Test(unittest.TestCase): + + def setUp(self): + from acme.challenges import TLSSNI01 + self.msg = TLSSNI01( + token=jose.b64decode('a82d5ff8ef740d12881f6d3c2277ab2e')) + self.jmsg = { + 'type': 'tls-sni-01', + 'token': 'a82d5ff8ef740d12881f6d3c2277ab2e', + } + + def test_to_partial_json(self): + self.assertEqual(self.jmsg, self.msg.to_partial_json()) + + def test_from_json(self): + from acme.challenges import TLSSNI01 + self.assertEqual(self.msg, TLSSNI01.from_json(self.jmsg)) + + def test_from_json_hashable(self): + from acme.challenges import TLSSNI01 + hash(TLSSNI01.from_json(self.jmsg)) + + def test_from_json_invalid_token_length(self): + from acme.challenges import TLSSNI01 + self.jmsg['token'] = jose.encode_b64jose(b'abcd') + self.assertRaises( + jose.DeserializationError, TLSSNI01.from_json, self.jmsg) + + @mock.patch('acme.challenges.TLSSNI01Response.gen_cert') + def test_validation(self, mock_gen_cert): + mock_gen_cert.return_value = ('cert', 'key') + self.assertEqual(('cert', 'key'), self.msg.validation( + KEY, cert_key=mock.sentinel.cert_key)) + mock_gen_cert.assert_called_once_with(key=mock.sentinel.cert_key) + + class DVSNITest(unittest.TestCase): def setUp(self):