From 4407210e0114bcacb75567abb71dce27f908c96f Mon Sep 17 00:00:00 2001 From: Jakub Warmuz Date: Tue, 30 Jun 2015 13:15:11 +0000 Subject: [PATCH] Fix --no-verify-ssl in HEAD, refactor acme.client_tests. Fix #521 by introducing MissingNonceError, which by shows response headers when printed to STDOUT. More sensible solution (a'la #523) is blocked by boulder#417 (HTTP 405 response for HEAD). Split out ClientNetworkWithMockedResponseTest from ClientNetworkTest, which improves readability and makes it easier to test (less mocks). --- acme/client.py | 83 ++++++++++---------- acme/client_test.py | 181 +++++++++++++++++++++++++------------------- acme/errors.py | 40 +++++++++- acme/errors_test.py | 33 ++++++++ 4 files changed, 218 insertions(+), 119 deletions(-) create mode 100644 acme/errors_test.py diff --git a/acme/client.py b/acme/client.py index aac539974..4a4192528 100644 --- a/acme/client.py +++ b/acme/client.py @@ -506,24 +506,47 @@ class ClientNetwork(object): raise errors.ClientError( 'Unexpected response Content-Type: {0}'.format(response_ct)) - def get(self, uri, content_type=JSON_CONTENT_TYPE, **kwargs): - """Send GET request. + return response - :raises .ClientError: + def _send_request(self, method, url, *args, **kwargs): + """Send HTTP request. + + Makes sure that `verify_ssl` is respected. Logs request and + response (with headers). For allowed parameters please see + `requests.request`. + + :param str method: method for the new `requests.Request` object + :param str url: URL for the new `requests.Request` object + + :raises requests.exceptions.RequestException: in case of any problems :returns: HTTP Response :rtype: `requests.Response` + """ - logger.debug('Sending GET request to %s', uri) - kwargs.setdefault('verify', self.verify_ssl) - try: - response = requests.get(uri, **kwargs) - except requests.exceptions.RequestException as error: - raise errors.ClientError(error) - self._check_response(response, content_type=content_type) + logging.debug('Sending %s request to %s', method, url) + kwargs['verify'] = self.verify_ssl + response = requests.request(method, url, *args, **kwargs) + logging.debug('Received %s. Headers: %s. Content: %r', + response, response.headers, response.content) return response + def head(self, *args, **kwargs): + """Send HEAD request without checking the response. + + Note, that `_check_response` is not called, as it is expected + that status code other than successfuly 2xx will be returned, or + messages2.Error will be raised by the server. + + """ + return self._send_request('HEAD', *args, **kwargs) + + def get(self, url, content_type=JSON_CONTENT_TYPE, **kwargs): + """Send GET request and check response.""" + return self._check_response( + self._send_request('GET', url, **kwargs), content_type=content_type) + def _add_nonce(self, response): if self.REPLAY_NONCE_HEADER in response.headers: nonce = response.headers[self.REPLAY_NONCE_HEADER] @@ -532,39 +555,19 @@ class ClientNetwork(object): logger.debug('Storing nonce: %r', nonce) self._nonces.add(nonce) else: - raise errors.ClientError('Invalid nonce ({0}): {1}'.format( - nonce, error)) + raise errors.BadNonce(nonce, error) else: - raise errors.ClientError( - 'Server {0} response did not include a replay nonce'.format( - response.request.method)) + raise errors.MissingNonce(response) - def _get_nonce(self, uri): + def _get_nonce(self, url): if not self._nonces: - logger.debug('Requesting fresh nonce by sending HEAD to %s', uri) - self._add_nonce(requests.head(uri)) + logging.debug('Requesting fresh nonce') + self._add_nonce(self.head(url)) return self._nonces.pop() - def post(self, uri, obj, content_type=JSON_CONTENT_TYPE, **kwargs): - """Send POST data. - - :param JSONDeSerializable obj: Will be wrapped in JWS. - :param str content_type: Expected ``Content-Type``, fails if not set. - - :raises acme.messages.ClientError: - - :returns: HTTP Response - :rtype: `requests.Response` - - """ - data = self._wrap_in_jws(obj, self._get_nonce(uri)) - logger.debug('Sending POST data to %s: %s', uri, data) - kwargs.setdefault('verify', self.verify_ssl) - try: - response = requests.post(uri, data=data, **kwargs) - except requests.exceptions.RequestException as error: - raise errors.ClientError(error) - + def post(self, url, obj, content_type=JSON_CONTENT_TYPE, **kwargs): + """POST object wrapped in `.JWS` and check response.""" + data = self._wrap_in_jws(obj, self._get_nonce(url)) + response = self._send_request('POST', url, data=data, **kwargs) self._add_nonce(response) - self._check_response(response, content_type=content_type) - return response + return self._check_response(response, content_type=content_type) diff --git a/acme/client_test.py b/acme/client_test.py index e935d9563..b934e1efd 100644 --- a/acme/client_test.py +++ b/acme/client_test.py @@ -357,11 +357,6 @@ class ClientNetworkTest(unittest.TestCase): self.net = ClientNetwork( key=KEY, alg=jose.RS256, verify_ssl=self.verify_ssl) - self.nonce = jose.b64encode('Nonce') - # pylint: disable=protected-access - self.assertEqual(self.net._nonces, set()) - self.net._nonces.add(self.nonce) - self.response = mock.MagicMock(ok=True, status_code=httplib.OK) self.response.headers = {} self.response.links = {} @@ -422,97 +417,127 @@ class ClientNetworkTest(unittest.TestCase): self.response.json.side_effect = ValueError for response_ct in [self.net.JSON_CONTENT_TYPE, 'foo']: self.response.headers['Content-Type'] = response_ct - # pylint: disable=protected-access - self.net._check_response(self.response) + # pylint: disable=protected-access,no-value-for-parameter + self.assertEqual( + self.response, self.net._check_response(self.response)) def test_check_response_jobj(self): self.response.json.return_value = {} for response_ct in [self.net.JSON_CONTENT_TYPE, 'foo']: self.response.headers['Content-Type'] = response_ct + # pylint: disable=protected-access,no-value-for-parameter + self.assertEqual( + self.response, self.net._check_response(self.response)) + + @mock.patch('acme.client.requests') + def test_send_request(self, mock_requests): + mock_requests.request.return_value = self.response + # pylint: disable=protected-access + self.assertEqual(self.response, self.net._send_request( + 'HEAD', 'url', 'foo', bar='baz')) + mock_requests.request.assert_called_once_with( + 'HEAD', 'url', 'foo', verify=mock.ANY, bar='baz') + + @mock.patch('acme.client.requests') + def test_send_request_verify_ssl(self, mock_requests): + # pylint: disable=protected-access + for verify in True, False: + mock_requests.request.reset_mock() + mock_requests.request.return_value = self.response + self.net.verify_ssl = verify # pylint: disable=protected-access - self.net._check_response(self.response) + self.assertEqual( + self.response, self.net._send_request('GET', 'url')) + mock_requests.request.assert_called_once_with( + 'GET', 'url', verify=verify) @mock.patch('acme.client.requests') - def test_get_requests_error_passthrough(self, requests_mock): - requests_mock.exceptions = requests.exceptions - requests_mock.get.side_effect = requests.exceptions.RequestException - self.assertRaises(errors.ClientError, self.net.get, 'uri') - - @mock.patch('acme.client.requests') - def test_get(self, requests_mock): + def test_requests_error_passthrough(self, mock_requests): + mock_requests.exceptions = requests.exceptions + mock_requests.request.side_effect = requests.exceptions.RequestException # pylint: disable=protected-access - self.net._check_response = mock.MagicMock() - self.net.get('uri', content_type='ct') - self.net._check_response.assert_called_once_with( - requests_mock.get('uri'), content_type='ct') + self.assertRaises(requests.exceptions.RequestException, + self.net._send_request, 'GET', 'uri') + + +class ClientNetworkWithMockedResponseTest(unittest.TestCase): + """Tests for acme.client.ClientNetwork which mock out response.""" + # pylint: disable=too-many-instance-attributes + + def setUp(self): + from acme.client import ClientNetwork + self.net = ClientNetwork(key=None, alg=None) + + self.response = mock.MagicMock(ok=True, status_code=httplib.OK) + self.response.headers = {} + self.response.links = {} + self.checked_response = mock.MagicMock() + self.obj = mock.MagicMock() + self.wrapped_obj = mock.MagicMock() + self.content_type = mock.sentinel.content_type + + self.all_nonces = [jose.b64encode('Nonce'), jose.b64encode('Nonce2')] + self.available_nonces = self.all_nonces[:] + def send_request(*args, **kwargs): + # pylint: disable=unused-argument,missing-docstring + if self.available_nonces: + self.response.headers = { + self.net.REPLAY_NONCE_HEADER: self.available_nonces.pop()} + else: + self.response.headers = {} + return self.response - def _mock_wrap_in_jws(self): # pylint: disable=protected-access - self.net._wrap_in_jws = self.wrap_in_jws + self.net._send_request = self.send_request = mock.MagicMock( + side_effect=send_request) + self.net._check_response = self.check_response + self.net._wrap_in_jws = mock.MagicMock(return_value=self.wrapped_obj) - @mock.patch('acme.client.requests') - def test_post_requests_error_passthrough(self, requests_mock): - requests_mock.exceptions = requests.exceptions - requests_mock.post.side_effect = requests.exceptions.RequestException - self._mock_wrap_in_jws() - self.assertRaises( - errors.ClientError, self.net.post, 'uri', mock.sentinel.obj) + def check_response(self, response, content_type): + # pylint: disable=missing-docstring + self.assertEqual(self.response, response) + self.assertEqual(self.content_type, content_type) + return self.checked_response - @mock.patch('acme.client.requests') - def test_post(self, requests_mock): + def test_head(self): + self.assertEqual(self.response, self.net.head('url', 'foo', bar='baz')) + self.send_request.assert_called_once('HEAD', 'url', 'foo', bar='baz') + + def test_get(self): + self.assertEqual(self.checked_response, self.net.get( + 'url', content_type=self.content_type, bar='baz')) + self.send_request.assert_called_once_with('GET', 'url', bar='baz') + + def test_post(self): # pylint: disable=protected-access - self.net._check_response = mock.MagicMock() - self._mock_wrap_in_jws() - requests_mock.post().headers = { - self.net.REPLAY_NONCE_HEADER: self.nonce} - self.net.post('uri', mock.sentinel.obj, content_type='ct') - self.net._check_response.assert_called_once_with( - requests_mock.post('uri', mock.sentinel.wrapped), content_type='ct') + self.assertEqual(self.checked_response, self.net.post( + 'uri', self.obj, content_type=self.content_type)) + self.net._wrap_in_jws.assert_called_once_with( + self.obj, self.all_nonces.pop()) - @mock.patch('acme.client.requests') - def test_post_replay_nonce_handling(self, requests_mock): - # pylint: disable=protected-access - self.net._check_response = mock.MagicMock() - self._mock_wrap_in_jws() + assert not self.available_nonces + self.assertRaises(errors.MissingNonce, self.net.post, + 'uri', self.obj, content_type=self.content_type) + self.net._wrap_in_jws.assert_called_with( + self.obj, self.all_nonces.pop()) - self.net._nonces.clear() - self.assertRaises( - errors.ClientError, self.net.post, 'uri', mock.sentinel.obj) + def test_post_wrong_initial_nonce(self): # HEAD + self.available_nonces = ['f', jose.b64encode('good')] + self.assertRaises(errors.BadNonce, self.net.post, 'uri', + self.obj, content_type=self.content_type) - nonce2 = jose.b64encode('Nonce2') - requests_mock.head('uri').headers = { - self.net.REPLAY_NONCE_HEADER: nonce2} - requests_mock.post('uri').headers = { - self.net.REPLAY_NONCE_HEADER: self.nonce} + def test_post_wrong_post_response_nonce(self): + self.available_nonces = [jose.b64encode('good'), 'f'] + self.assertRaises(errors.BadNonce, self.net.post, 'uri', + self.obj, content_type=self.content_type) - self.net.post('uri', mock.sentinel.obj) - - requests_mock.head.assert_called_with('uri') - self.wrap_in_jws.assert_called_once_with(mock.sentinel.obj, nonce2) - self.assertEqual(self.net._nonces, set([self.nonce])) - - # wrong nonce - requests_mock.post('uri').headers = {self.net.REPLAY_NONCE_HEADER: 'F'} - self.assertRaises( - errors.ClientError, self.net.post, 'uri', mock.sentinel.obj) - - @mock.patch('acme.client.requests') - def test_get_post_verify_ssl(self, requests_mock): - # pylint: disable=protected-access - self._mock_wrap_in_jws() - self.net._check_response = mock.MagicMock() - - for verify_ssl in [True, False]: - self.net.verify_ssl = verify_ssl - self.net.get('uri') - self.net._nonces.add('N') - requests_mock.post().headers = { - self.net.REPLAY_NONCE_HEADER: self.nonce} - self.net.post('uri', mock.sentinel.obj) - requests_mock.get.assert_called_once_with('uri', verify=verify_ssl) - requests_mock.post.assert_called_with( - 'uri', data=mock.sentinel.wrapped, verify=verify_ssl) - requests_mock.reset_mock() + def test_head_get_post_error_passthrough(self): + self.send_request.side_effect = requests.exceptions.RequestException + for method in self.net.head, self.net.get: + self.assertRaises( + requests.exceptions.RequestException, method, 'GET', 'uri') + self.assertRaises(requests.exceptions.RequestException, + self.net.post, 'uri', obj=self.obj) if __name__ == '__main__': diff --git a/acme/errors.py b/acme/errors.py index 5046d7aee..9a96ec43a 100644 --- a/acme/errors.py +++ b/acme/errors.py @@ -5,11 +5,49 @@ from acme.jose import errors as jose_errors class Error(Exception): """Generic ACME error.""" + class SchemaValidationError(jose_errors.DeserializationError): """JSON schema ACME object validation error.""" + class ClientError(Error): """Network error.""" + class UnexpectedUpdate(ClientError): - """Unexpected update.""" + """Unexpected update error.""" + + +class NonceError(ClientError): + """Server response nonce error.""" + + +class BadNonce(NonceError): + """Bad nonce error.""" + def __init__(self, nonce, error, *args, **kwargs): + super(BadNonce, self).__init__(*args, **kwargs) + self.nonce = nonce + self.error = error + + def __str__(self): + return 'Invalid nonce ({0!r}): {1}'.format(self.nonce, self.error) + + +class MissingNonce(NonceError): + """Missing nonce error. + + According to the specification an "ACME server MUST include an + Replay-Nonce header field in each successful response to a POST it + provides to a client (...)". + + :ivar requests.Response response: HTTP Response + + """ + def __init__(self, response, *args, **kwargs): + super(MissingNonce, self).__init__(*args, **kwargs) + self.response = response + + def __str__(self): + return ('Server {0} response did not include a replay ' + 'nonce, headers: {1}'.format( + self.response.request.method, self.response.headers)) diff --git a/acme/errors_test.py b/acme/errors_test.py new file mode 100644 index 000000000..3790d91ed --- /dev/null +++ b/acme/errors_test.py @@ -0,0 +1,33 @@ +"""Tests for acme.errors.""" +import unittest + +import mock + + +class BadNonceTest(unittest.TestCase): + """Tests for acme.errors.BadNonce.""" + + def setUp(self): + from acme.errors import BadNonce + self.error = BadNonce(nonce="xxx", error="error") + + def test_str(self): + self.assertEqual("Invalid nonce ('xxx'): error", str(self.error)) + + +class MissingNonceTest(unittest.TestCase): + """Tests for acme.errors.MissingNonce.""" + + def setUp(self): + from acme.errors import MissingNonce + self.response = mock.MagicMock(headers={}) + self.response.request.method = 'FOO' + self.error = MissingNonce(self.response) + + def test_str(self): + self.assertTrue("FOO" in str(self.error)) + self.assertTrue("{}" in str(self.error)) + + +if __name__ == "__main__": + unittest.main() # pragma: no cover