diff --git a/acme/acme/jose/__init__.py b/acme/acme/jose/__init__.py index 793b342b0..76969139b 100644 --- a/acme/acme/jose/__init__.py +++ b/acme/acme/jose/__init__.py @@ -44,8 +44,10 @@ from acme.jose.json_util import ( decode_cert, decode_csr, decode_hex16, + encode_b64jose, encode_cert, encode_csr, + encode_hex16, ) from acme.jose.jwa import ( diff --git a/acme/acme/jose/b64.py b/acme/acme/jose/b64.py index 8f2d284ce..5fccdce2e 100644 --- a/acme/acme/jose/b64.py +++ b/acme/acme/jose/b64.py @@ -9,28 +9,31 @@ .. _`JOSE Base64`: https://tools.ietf.org/html/draft-ietf-jose-json-web-signature-37#appendix-C -.. warning:: Do NOT try to call this module "base64", - as it will "shadow" the standard library. +.. Do NOT try to call this module "base64", as it will "shadow" the + standard library. """ import base64 +import six + def b64encode(data): """JOSE Base64 encode. :param data: Data to be encoded. - :type data: str or bytearray + :type data: `bytes` or `bytearray` :returns: JOSE Base64 string. - :rtype: str + :rtype: bytes :raises TypeError: if `data` is of incorrect type """ - if not isinstance(data, str): - raise TypeError('argument should be str or bytearray') - return base64.urlsafe_b64encode(data).rstrip('=') + if not isinstance(data, (six.binary_type, bytearray)): + raise TypeError('argument should be {0} or bytearray'.format( + six.binary_type)) + return base64.urlsafe_b64encode(data).rstrip(b'=') def b64decode(data): @@ -38,21 +41,22 @@ def b64decode(data): :param data: Base64 string to be decoded. If it's unicode, then only ASCII characters are allowed. - :type data: str or unicode + :type data: `bytes` or `unicode` :returns: Decoded data. + :rtype: bytes :raises TypeError: if input is of incorrect type :raises ValueError: if input is unicode with non-ASCII characters """ - if isinstance(data, unicode): + if isinstance(data, six.string_types): try: data = data.encode('ascii') except UnicodeEncodeError: raise ValueError( 'unicode argument should contain only ASCII characters') - elif not isinstance(data, str): + elif not isinstance(data, six.binary_type): raise TypeError('argument should be a str or unicode') - return base64.urlsafe_b64decode(data + '=' * (4 - (len(data) % 4))) + return base64.urlsafe_b64decode(data + b'=' * (4 - (len(data) % 4))) diff --git a/acme/acme/jose/b64_test.py b/acme/acme/jose/b64_test.py index 0c243cb2a..989f8e7fe 100644 --- a/acme/acme/jose/b64_test.py +++ b/acme/acme/jose/b64_test.py @@ -1,20 +1,22 @@ """Tests for acme.jose.b64.""" import unittest +import six + # https://en.wikipedia.org/wiki/Base64#Examples B64_PADDING_EXAMPLES = { - 'any carnal pleasure.': ('YW55IGNhcm5hbCBwbGVhc3VyZS4', '='), - 'any carnal pleasure': ('YW55IGNhcm5hbCBwbGVhc3VyZQ', '=='), - 'any carnal pleasur': ('YW55IGNhcm5hbCBwbGVhc3Vy', ''), - 'any carnal pleasu': ('YW55IGNhcm5hbCBwbGVhc3U', '='), - 'any carnal pleas': ('YW55IGNhcm5hbCBwbGVhcw', '=='), + b'any carnal pleasure.': (b'YW55IGNhcm5hbCBwbGVhc3VyZS4', b'='), + b'any carnal pleasure': (b'YW55IGNhcm5hbCBwbGVhc3VyZQ', b'=='), + b'any carnal pleasur': (b'YW55IGNhcm5hbCBwbGVhc3Vy', b''), + b'any carnal pleasu': (b'YW55IGNhcm5hbCBwbGVhc3U', b'='), + b'any carnal pleas': (b'YW55IGNhcm5hbCBwbGVhcw', b'=='), } B64_URL_UNSAFE_EXAMPLES = { - chr(251) + chr(239): '--8', - chr(255) * 2: '__8', + six.int2byte(251) + six.int2byte(239): b'--8', + six.int2byte(255) * 2: b'__8', } @@ -26,14 +28,20 @@ class B64EncodeTest(unittest.TestCase): from acme.jose.b64 import b64encode return b64encode(data) + def test_empty(self): + self.assertEqual(self._call(b''), b'') + def test_unsafe_url(self): - for text, b64 in B64_URL_UNSAFE_EXAMPLES.iteritems(): + for text, b64 in six.iteritems(B64_URL_UNSAFE_EXAMPLES): self.assertEqual(self._call(text), b64) def test_different_paddings(self): - for text, (b64, _) in B64_PADDING_EXAMPLES.iteritems(): + for text, (b64, _) in six.iteritems(B64_PADDING_EXAMPLES): self.assertEqual(self._call(text), b64) + def test_bytearray_ok(self): + self.assertEqual(self._call(bytearray(b'foo')), b'Zm9v') + def test_unicode_fails_with_type_error(self): self.assertRaises(TypeError, self._call, u'some unicode') @@ -47,24 +55,24 @@ class B64DecodeTest(unittest.TestCase): return b64decode(data) def test_unsafe_url(self): - for text, b64 in B64_URL_UNSAFE_EXAMPLES.iteritems(): + for text, b64 in six.iteritems(B64_URL_UNSAFE_EXAMPLES): self.assertEqual(self._call(b64), text) def test_input_without_padding(self): - for text, (b64, _) in B64_PADDING_EXAMPLES.iteritems(): + for text, (b64, _) in six.iteritems(B64_PADDING_EXAMPLES): self.assertEqual(self._call(b64), text) def test_input_with_padding(self): - for text, (b64, pad) in B64_PADDING_EXAMPLES.iteritems(): + for text, (b64, pad) in six.iteritems(B64_PADDING_EXAMPLES): self.assertEqual(self._call(b64 + pad), text) def test_unicode_with_ascii(self): - self.assertEqual(self._call(u'YQ'), 'a') + self.assertEqual(self._call(u'YQ'), b'a') def test_non_ascii_unicode_fails(self): self.assertRaises(ValueError, self._call, u'\u0105') - def test_type_error_no_unicode_or_str(self): + def test_type_error_no_unicode_or_bytes(self): self.assertRaises(TypeError, self._call, object()) diff --git a/acme/acme/jose/interfaces.py b/acme/acme/jose/interfaces.py index 27dcf863f..96dae6bae 100644 --- a/acme/acme/jose/interfaces.py +++ b/acme/acme/jose/interfaces.py @@ -3,12 +3,15 @@ import abc import collections import json +import six + from acme.jose import util # pylint: disable=no-self-argument,no-method-argument,no-init,inherit-non-class # pylint: disable=too-few-public-methods +@six.add_metaclass(abc.ABCMeta) class JSONDeSerializable(object): # pylint: disable=too-few-public-methods """Interface for (de)serializable JSON objects. @@ -96,7 +99,6 @@ class JSONDeSerializable(object): return Bar() """ - __metaclass__ = abc.ABCMeta @abc.abstractmethod def to_partial_json(self): # pragma: no cover @@ -133,7 +135,7 @@ class JSONDeSerializable(object): def _serialize(obj): if isinstance(obj, JSONDeSerializable): return _serialize(obj.to_partial_json()) - if isinstance(obj, basestring): # strings are sequence + if isinstance(obj, six.string_types): # strings are Sequence return obj elif isinstance(obj, list): return [_serialize(subobj) for subobj in obj] @@ -143,14 +145,14 @@ class JSONDeSerializable(object): return tuple(_serialize(subobj) for subobj in obj) elif isinstance(obj, collections.Mapping): return dict((_serialize(key), _serialize(value)) - for key, value in obj.iteritems()) + for key, value in six.iteritems(obj)) else: return obj return _serialize(self) @util.abstractclassmethod - def from_json(cls, unused_jobj): + def from_json(cls, jobj): # pylint: disable=unused-argument """Deserialize a decoded JSON document. :param jobj: Python object, composed of only other basic data @@ -182,7 +184,11 @@ class JSONDeSerializable(object): return json.dumps(self, default=self.json_dump_default, **kwargs) def json_dumps_pretty(self): - """Dump the object to pretty JSON document string.""" + """Dump the object to pretty JSON document string. + + :rtype: str + + """ return self.json_dumps(sort_keys=True, indent=4, separators=(',', ': ')) @classmethod @@ -190,7 +196,7 @@ class JSONDeSerializable(object): """Serialize Python object. This function is meant to be passed as ``default`` to - :func:`json.load` or :func:`json.loads`. They call + :func:`json.dump` or :func:`json.dumps`. They call ``default(python_object)`` only for non-basic Python types, so this function necessarily raises :class:`TypeError` if ``python_object`` is not an instance of diff --git a/acme/acme/jose/json_util.py b/acme/acme/jose/json_util.py index fe3831296..c531efd9d 100644 --- a/acme/acme/jose/json_util.py +++ b/acme/acme/jose/json_util.py @@ -11,6 +11,7 @@ import binascii import logging import OpenSSL +import six from acme.jose import b64 from acme.jose import errors @@ -109,7 +110,7 @@ class Field(object): elif isinstance(value, dict): return util.frozendict( dict((cls.default_decoder(key), cls.default_decoder(value)) - for key, value in value.iteritems())) + for key, value in six.iteritems(value))) else: # integer or string return value @@ -167,17 +168,20 @@ class JSONObjectWithFieldsMeta(abc.ABCMeta): for base in bases: fields.update(getattr(base, '_fields', {})) # Do not reorder, this class might override fields from base classes! - for key, value in dikt.items(): # not iterkeys() (in-place edit!) + for key, value in tuple(six.iteritems(dikt)): + # not six.iterkeys() (in-place edit!) if isinstance(value, Field): fields[key] = dikt.pop(key) dikt['_orig_slots'] = dikt.get('__slots__', ()) - dikt['__slots__'] = tuple(list(dikt['_orig_slots']) + fields.keys()) + dikt['__slots__'] = tuple( + list(dikt['_orig_slots']) + list(six.iterkeys(fields))) dikt['_fields'] = fields return abc.ABCMeta.__new__(mcs, name, bases, dikt) +@six.add_metaclass(JSONObjectWithFieldsMeta) class JSONObjectWithFields(util.ImmutableMap, interfaces.JSONDeSerializable): # pylint: disable=too-few-public-methods """JSON object with fields. @@ -205,13 +209,12 @@ class JSONObjectWithFields(util.ImmutableMap, interfaces.JSONDeSerializable): assert Foo(bar='baz').bar == 'baz' """ - __metaclass__ = JSONObjectWithFieldsMeta @classmethod def _defaults(cls): """Get default fields values.""" return dict([(slot, field.default) for slot, field - in cls._fields.iteritems() if field.omitempty]) + in six.iteritems(cls._fields) if field.omitempty]) def __init__(self, **kwargs): # pylint: disable=star-args @@ -222,7 +225,7 @@ class JSONObjectWithFields(util.ImmutableMap, interfaces.JSONDeSerializable): """Serialize fields to JSON.""" jobj = {} omitted = set() - for slot, field in self._fields.iteritems(): + for slot, field in six.iteritems(self._fields): value = getattr(self, slot) if field.omit(value): @@ -246,7 +249,7 @@ class JSONObjectWithFields(util.ImmutableMap, interfaces.JSONDeSerializable): @classmethod def _check_required(cls, jobj): missing = set() - for _, field in cls._fields.iteritems(): + for _, field in six.iteritems(cls._fields): if not field.omitempty and field.json_name not in jobj: missing.add(field.json_name) @@ -260,7 +263,7 @@ class JSONObjectWithFields(util.ImmutableMap, interfaces.JSONDeSerializable): """Deserialize fields from JSON.""" cls._check_required(jobj) fields = {} - for slot, field in cls._fields.iteritems(): + for slot, field in six.iteritems(cls._fields): if field.json_name not in jobj and field.omitempty: fields[slot] = field.default else: @@ -278,17 +281,31 @@ class JSONObjectWithFields(util.ImmutableMap, interfaces.JSONDeSerializable): return cls(**cls.fields_from_json(jobj)) +def encode_b64jose(data): + """Encode JOSE Base-64 field. + + :param bytes data: + :rtype: `unicode` + + """ + # b64encode produces ASCII characters only + return b64.b64encode(data).decode('ascii') + def decode_b64jose(data, size=None, minimum=False): """Decode JOSE Base-64 field. + :param unicode data: :param int size: Required length (after decoding). :param bool minimum: If ``True``, then `size` will be treated as minimum required length, as opposed to exact equality. + :rtype: bytes + """ + error_cls = TypeError if six.PY2 else binascii.Error try: - decoded = b64.b64decode(data) - except TypeError as error: + decoded = b64.b64decode(data.encode()) + except error_cls as error: raise errors.DeserializationError(error) if size is not None and ((not minimum and len(decoded) != size) @@ -297,35 +314,53 @@ def decode_b64jose(data, size=None, minimum=False): return decoded +def encode_hex16(value): + """Hexlify. + + :param bytes value: + :rtype: unicode + + """ + return binascii.hexlify(value).decode() def decode_hex16(value, size=None, minimum=False): """Decode hexlified field. + :param unicode value: :param int size: Required length (after decoding). :param bool minimum: If ``True``, then `size` will be treated as minimum required length, as opposed to exact equality. + :rtype: bytes + """ + value = value.encode() if size is not None and ((not minimum and len(value) != size * 2) or (minimum and len(value) < size * 2)): raise errors.DeserializationError() + error_cls = TypeError if six.PY2 else binascii.Error try: return binascii.unhexlify(value) - except TypeError as error: + except error_cls as error: raise errors.DeserializationError(error) def encode_cert(cert): """Encode certificate as JOSE Base-64 DER. - :param cert: Certificate. - :type cert: :class:`acme.jose.util.ComparableX509` + :type cert: `OpenSSL.crypto.X509` wrapped in `.ComparableX509` + :rtype: unicode """ - return b64.b64encode(OpenSSL.crypto.dump_certificate( + return encode_b64jose(OpenSSL.crypto.dump_certificate( OpenSSL.crypto.FILETYPE_ASN1, cert)) def decode_cert(b64der): - """Decode JOSE Base-64 DER-encoded certificate.""" + """Decode JOSE Base-64 DER-encoded certificate. + + :param unicode b64der: + :rtype: `OpenSSL.crypto.X509` wrapped in `.ComparableX509` + + """ try: return util.ComparableX509(OpenSSL.crypto.load_certificate( OpenSSL.crypto.FILETYPE_ASN1, decode_b64jose(b64der))) @@ -333,12 +368,22 @@ def decode_cert(b64der): raise errors.DeserializationError(error) def encode_csr(csr): - """Encode CSR as JOSE Base-64 DER.""" - return b64.b64encode(OpenSSL.crypto.dump_certificate_request( + """Encode CSR as JOSE Base-64 DER. + + :type csr: `OpenSSL.crypto.X509Req` wrapped in `.ComparableX509` + :rtype: unicode + + """ + return encode_b64jose(OpenSSL.crypto.dump_certificate_request( OpenSSL.crypto.FILETYPE_ASN1, csr)) def decode_csr(b64der): - """Decode JOSE Base-64 DER-encoded CSR.""" + """Decode JOSE Base-64 DER-encoded CSR. + + :param unicode b64der: + :rtype: `OpenSSL.crypto.X509Req` wrapped in `.ComparableX509` + + """ try: return util.ComparableX509(OpenSSL.crypto.load_certificate_request( OpenSSL.crypto.FILETYPE_ASN1, decode_b64jose(b64der))) @@ -372,7 +417,7 @@ class TypedJSONObjectWithFields(JSONObjectWithFields): @classmethod def get_type_cls(cls, jobj): """Get the registered class for ``jobj``.""" - if cls in cls.TYPES.itervalues(): + if cls in six.itervalues(cls.TYPES): assert jobj[cls.type_field_name] # cls is already registered type_cls, force to use it # so that, e.g Revocation.from_json(jobj) fails if diff --git a/acme/acme/jose/json_util_test.py b/acme/acme/jose/json_util_test.py index 9e2a87858..2225267ee 100644 --- a/acme/acme/jose/json_util_test.py +++ b/acme/acme/jose/json_util_test.py @@ -3,6 +3,7 @@ import itertools import unittest import mock +import six from acme import test_util @@ -92,8 +93,8 @@ class JSONObjectWithFieldsMetaTest(unittest.TestCase): self.field2 = Field('Baz2') # pylint: disable=invalid-name,missing-docstring,too-few-public-methods # pylint: disable=blacklisted-name + @six.add_metaclass(JSONObjectWithFieldsMeta) class A(object): - __metaclass__ = JSONObjectWithFieldsMeta __slots__ = ('bar',) baz = self.field class B(A): @@ -207,62 +208,82 @@ class JSONObjectWithFieldsTest(unittest.TestCase): class DeEncodersTest(unittest.TestCase): def setUp(self): self.b64_cert = ( - 'MIIB3jCCAYigAwIBAgICBTkwDQYJKoZIhvcNAQELBQAwdzELMAkGA1UEBhM' - 'CVVMxETAPBgNVBAgMCE1pY2hpZ2FuMRIwEAYDVQQHDAlBbm4gQXJib3IxKz' - 'ApBgNVBAoMIlVuaXZlcnNpdHkgb2YgTWljaGlnYW4gYW5kIHRoZSBFRkYxF' - 'DASBgNVBAMMC2V4YW1wbGUuY29tMB4XDTE0MTIxMTIyMzQ0NVoXDTE0MTIx' - 'ODIyMzQ0NVowdzELMAkGA1UEBhMCVVMxETAPBgNVBAgMCE1pY2hpZ2FuMRI' - 'wEAYDVQQHDAlBbm4gQXJib3IxKzApBgNVBAoMIlVuaXZlcnNpdHkgb2YgTW' - 'ljaGlnYW4gYW5kIHRoZSBFRkYxFDASBgNVBAMMC2V4YW1wbGUuY29tMFwwD' - 'QYJKoZIhvcNAQEBBQADSwAwSAJBAKx1c7RR7R_drnBSQ_zfx1vQLHUbFLh1' - 'AQQQ5R8DZUXd36efNK79vukFhN9HFoHZiUvOjm0c-pVE6K-EdE_twuUCAwE' - 'AATANBgkqhkiG9w0BAQsFAANBAC24z0IdwIVKSlntksllvr6zJepBH5fMnd' - 'fk3XJp10jT6VE-14KNtjh02a56GoraAvJAT5_H67E8GvJ_ocNnB_o' + u'MIIB3jCCAYigAwIBAgICBTkwDQYJKoZIhvcNAQELBQAwdzELMAkGA1UEBhM' + u'CVVMxETAPBgNVBAgMCE1pY2hpZ2FuMRIwEAYDVQQHDAlBbm4gQXJib3IxKz' + u'ApBgNVBAoMIlVuaXZlcnNpdHkgb2YgTWljaGlnYW4gYW5kIHRoZSBFRkYxF' + u'DASBgNVBAMMC2V4YW1wbGUuY29tMB4XDTE0MTIxMTIyMzQ0NVoXDTE0MTIx' + u'ODIyMzQ0NVowdzELMAkGA1UEBhMCVVMxETAPBgNVBAgMCE1pY2hpZ2FuMRI' + u'wEAYDVQQHDAlBbm4gQXJib3IxKzApBgNVBAoMIlVuaXZlcnNpdHkgb2YgTW' + u'ljaGlnYW4gYW5kIHRoZSBFRkYxFDASBgNVBAMMC2V4YW1wbGUuY29tMFwwD' + u'QYJKoZIhvcNAQEBBQADSwAwSAJBAKx1c7RR7R_drnBSQ_zfx1vQLHUbFLh1' + u'AQQQ5R8DZUXd36efNK79vukFhN9HFoHZiUvOjm0c-pVE6K-EdE_twuUCAwE' + u'AATANBgkqhkiG9w0BAQsFAANBAC24z0IdwIVKSlntksllvr6zJepBH5fMnd' + u'fk3XJp10jT6VE-14KNtjh02a56GoraAvJAT5_H67E8GvJ_ocNnB_o' ) self.b64_csr = ( - 'MIIBXTCCAQcCAQAweTELMAkGA1UEBhMCVVMxETAPBgNVBAgMCE1pY2hpZ2F' - 'uMRIwEAYDVQQHDAlBbm4gQXJib3IxDDAKBgNVBAoMA0VGRjEfMB0GA1UECw' - 'wWVW5pdmVyc2l0eSBvZiBNaWNoaWdhbjEUMBIGA1UEAwwLZXhhbXBsZS5jb' - '20wXDANBgkqhkiG9w0BAQEFAANLADBIAkEArHVztFHtH92ucFJD_N_HW9As' - 'dRsUuHUBBBDlHwNlRd3fp580rv2-6QWE30cWgdmJS86ObRz6lUTor4R0T-3' - 'C5QIDAQABoCkwJwYJKoZIhvcNAQkOMRowGDAWBgNVHREEDzANggtleGFtcG' - 'xlLmNvbTANBgkqhkiG9w0BAQsFAANBAHJH_O6BtC9aGzEVCMGOZ7z9iIRHW' - 'Szr9x_bOzn7hLwsbXPAgO1QxEwL-X-4g20Gn9XBE1N9W6HCIEut2d8wACg' + u'MIIBXTCCAQcCAQAweTELMAkGA1UEBhMCVVMxETAPBgNVBAgMCE1pY2hpZ2F' + u'uMRIwEAYDVQQHDAlBbm4gQXJib3IxDDAKBgNVBAoMA0VGRjEfMB0GA1UECw' + u'wWVW5pdmVyc2l0eSBvZiBNaWNoaWdhbjEUMBIGA1UEAwwLZXhhbXBsZS5jb' + u'20wXDANBgkqhkiG9w0BAQEFAANLADBIAkEArHVztFHtH92ucFJD_N_HW9As' + u'dRsUuHUBBBDlHwNlRd3fp580rv2-6QWE30cWgdmJS86ObRz6lUTor4R0T-3' + u'C5QIDAQABoCkwJwYJKoZIhvcNAQkOMRowGDAWBgNVHREEDzANggtleGFtcG' + u'xlLmNvbTANBgkqhkiG9w0BAQsFAANBAHJH_O6BtC9aGzEVCMGOZ7z9iIRHW' + u'Szr9x_bOzn7hLwsbXPAgO1QxEwL-X-4g20Gn9XBE1N9W6HCIEut2d8wACg' ) - def test_decode_b64_jose_padding_error(self): - from acme.jose.json_util import decode_b64jose - self.assertRaises(errors.DeserializationError, decode_b64jose, 'x') + def test_encode_b64jose(self): + from acme.jose.json_util import encode_b64jose + encoded = encode_b64jose(b'x') + self.assertTrue(isinstance(encoded, six.string_types)) + self.assertEqual(u'eA', encoded) - def test_decode_b64_jose_size(self): + def test_decode_b64jose(self): from acme.jose.json_util import decode_b64jose - self.assertEqual('foo', decode_b64jose('Zm9v', size=3)) - self.assertRaises( - errors.DeserializationError, decode_b64jose, 'Zm9v', size=2) - self.assertRaises( - errors.DeserializationError, decode_b64jose, 'Zm9v', size=4) + decoded = decode_b64jose(u'eA') + self.assertTrue(isinstance(decoded, six.binary_type)) + self.assertEqual(b'x', decoded) - def test_decode_b64_jose_minimum_size(self): + def test_decode_b64jose_padding_error(self): from acme.jose.json_util import decode_b64jose - self.assertEqual('foo', decode_b64jose('Zm9v', size=3, minimum=True)) - self.assertEqual('foo', decode_b64jose('Zm9v', size=2, minimum=True)) + self.assertRaises(errors.DeserializationError, decode_b64jose, u'x') + + def test_decode_b64jose_size(self): + from acme.jose.json_util import decode_b64jose + self.assertEqual(b'foo', decode_b64jose(u'Zm9v', size=3)) + self.assertRaises( + errors.DeserializationError, decode_b64jose, u'Zm9v', size=2) + self.assertRaises( + errors.DeserializationError, decode_b64jose, u'Zm9v', size=4) + + def test_decode_b64jose_minimum_size(self): + from acme.jose.json_util import decode_b64jose + self.assertEqual(b'foo', decode_b64jose(u'Zm9v', size=3, minimum=True)) + self.assertEqual(b'foo', decode_b64jose(u'Zm9v', size=2, minimum=True)) self.assertRaises(errors.DeserializationError, decode_b64jose, - 'Zm9v', size=4, minimum=True) + u'Zm9v', size=4, minimum=True) + + def test_encode_hex16(self): + from acme.jose.json_util import encode_hex16 + encoded = encode_hex16(b'foo') + self.assertEqual(u'666f6f', encoded) + self.assertTrue(isinstance(encoded, six.string_types)) def test_decode_hex16(self): from acme.jose.json_util import decode_hex16 - self.assertEqual('foo', decode_hex16('666f6f')) + decoded = decode_hex16(u'666f6f') + self.assertEqual(b'foo', decoded) + self.assertTrue(isinstance(decoded, six.binary_type)) def test_decode_hex16_minimum_size(self): from acme.jose.json_util import decode_hex16 - self.assertEqual('foo', decode_hex16('666f6f', size=3, minimum=True)) - self.assertEqual('foo', decode_hex16('666f6f', size=2, minimum=True)) + self.assertEqual(b'foo', decode_hex16(u'666f6f', size=3, minimum=True)) + self.assertEqual(b'foo', decode_hex16(u'666f6f', size=2, minimum=True)) self.assertRaises(errors.DeserializationError, decode_hex16, - '666f6f', size=4, minimum=True) + u'666f6f', size=4, minimum=True) def test_decode_hex16_odd_length(self): from acme.jose.json_util import decode_hex16 - self.assertRaises(errors.DeserializationError, decode_hex16, 'x') + self.assertRaises(errors.DeserializationError, decode_hex16, u'x') def test_encode_cert(self): from acme.jose.json_util import encode_cert @@ -273,7 +294,7 @@ class DeEncodersTest(unittest.TestCase): cert = decode_cert(self.b64_cert) self.assertTrue(isinstance(cert, util.ComparableX509)) self.assertEqual(cert, CERT) - self.assertRaises(errors.DeserializationError, decode_cert, '') + self.assertRaises(errors.DeserializationError, decode_cert, u'') def test_encode_csr(self): from acme.jose.json_util import encode_csr @@ -284,7 +305,7 @@ class DeEncodersTest(unittest.TestCase): csr = decode_csr(self.b64_csr) self.assertTrue(isinstance(csr, util.ComparableX509)) self.assertEqual(csr, CSR) - self.assertRaises(errors.DeserializationError, decode_csr, '') + self.assertRaises(errors.DeserializationError, decode_csr, u'') class TypedJSONObjectWithFieldsTest(unittest.TestCase): diff --git a/acme/acme/jose/jwa.py b/acme/acme/jose/jwa.py index f081aa169..0c84905df 100644 --- a/acme/acme/jose/jwa.py +++ b/acme/acme/jose/jwa.py @@ -4,6 +4,7 @@ https://tools.ietf.org/html/draft-ietf-jose-json-web-algorithms-40 """ import abc +import collections import logging import cryptography.exceptions @@ -27,7 +28,7 @@ class JWA(interfaces.JSONDeSerializable): # pylint: disable=abstract-method """JSON Web Algorithm.""" -class JWASignature(JWA): +class JWASignature(JWA, collections.Hashable): """JSON Web Signature Algorithm.""" SIGNATURES = {} @@ -39,6 +40,9 @@ class JWASignature(JWA): return NotImplemented return self.name == other.name + def __hash__(self): + return hash((self.__class__, self.name)) + def __ne__(self, other): return not self == other diff --git a/acme/acme/jose/jwa_test.py b/acme/acme/jose/jwa_test.py index 1a3896f4a..3328d083a 100644 --- a/acme/acme/jose/jwa_test.py +++ b/acme/acme/jose/jwa_test.py @@ -58,12 +58,12 @@ class JWAHSTest(unittest.TestCase): # pylint: disable=too-few-public-methods def test_it(self): from acme.jose.jwa import HS256 sig = ( - "\xceR\xea\xcd\x94\xab\xcf\xfb\xe0\xacA.:\x1a'\x08i\xe2\xc4" - "\r\x85+\x0e\x85\xaeUZ\xd4\xb3\x97zO" + b"\xceR\xea\xcd\x94\xab\xcf\xfb\xe0\xacA.:\x1a'\x08i\xe2\xc4" + b"\r\x85+\x0e\x85\xaeUZ\xd4\xb3\x97zO" ) - self.assertEqual(HS256.sign('some key', 'foo'), sig) - self.assertTrue(HS256.verify('some key', 'foo', sig) is True) - self.assertTrue(HS256.verify('some key', 'foo', sig + '!') is False) + self.assertEqual(HS256.sign(b'some key', b'foo'), sig) + self.assertTrue(HS256.verify(b'some key', b'foo', sig) is True) + self.assertTrue(HS256.verify(b'some key', b'foo', sig + b'!') is False) class JWARSTest(unittest.TestCase): @@ -71,32 +71,33 @@ class JWARSTest(unittest.TestCase): def test_sign_no_private_part(self): from acme.jose.jwa import RS256 self.assertRaises( - errors.Error, RS256.sign, RSA512_KEY.public_key(), 'foo') + errors.Error, RS256.sign, RSA512_KEY.public_key(), b'foo') def test_sign_key_too_small(self): from acme.jose.jwa import RS256 from acme.jose.jwa import PS256 - self.assertRaises(errors.Error, RS256.sign, RSA256_KEY, 'foo') - self.assertRaises(errors.Error, PS256.sign, RSA256_KEY, 'foo') + self.assertRaises(errors.Error, RS256.sign, RSA256_KEY, b'foo') + self.assertRaises(errors.Error, PS256.sign, RSA256_KEY, b'foo') def test_rs(self): from acme.jose.jwa import RS256 sig = ( - '|\xc6\xb2\xa4\xab(\x87\x99\xfa*:\xea\xf8\xa0N&}\x9f\x0f\xc0O' - '\xc6t\xa3\xe6\xfa\xbb"\x15Y\x80Y\xe0\x81\xb8\x88)\xba\x0c\x9c' - '\xa4\x99\x1e\x19&\xd8\xc7\x99S\x97\xfc\x85\x0cOV\xe6\x07\x99' - '\xd2\xb9.>}\xfd' + b'|\xc6\xb2\xa4\xab(\x87\x99\xfa*:\xea\xf8\xa0N&}\x9f\x0f\xc0O' + b'\xc6t\xa3\xe6\xfa\xbb"\x15Y\x80Y\xe0\x81\xb8\x88)\xba\x0c\x9c' + b'\xa4\x99\x1e\x19&\xd8\xc7\x99S\x97\xfc\x85\x0cOV\xe6\x07\x99' + b'\xd2\xb9.>}\xfd' ) - self.assertEqual(RS256.sign(RSA512_KEY, 'foo'), sig) - self.assertTrue(RS256.verify(RSA512_KEY.public_key(), 'foo', sig)) + self.assertEqual(RS256.sign(RSA512_KEY, b'foo'), sig) + self.assertTrue(RS256.verify(RSA512_KEY.public_key(), b'foo', sig)) self.assertFalse(RS256.verify( - RSA512_KEY.public_key(), 'foo', sig + '!')) + RSA512_KEY.public_key(), b'foo', sig + b'!')) def test_ps(self): from acme.jose.jwa import PS256 - sig = PS256.sign(RSA1024_KEY, 'foo') - self.assertTrue(PS256.verify(RSA1024_KEY.public_key(), 'foo', sig)) - self.assertFalse(PS256.verify(RSA1024_KEY.public_key(), 'foo', sig + '!')) + sig = PS256.sign(RSA1024_KEY, b'foo') + self.assertTrue(PS256.verify(RSA1024_KEY.public_key(), b'foo', sig)) + self.assertFalse(PS256.verify( + RSA1024_KEY.public_key(), b'foo', sig + b'!')) if __name__ == '__main__': diff --git a/acme/acme/jose/jwk.py b/acme/acme/jose/jwk.py index 2b48d56e6..d9b903eb0 100644 --- a/acme/acme/jose/jwk.py +++ b/acme/acme/jose/jwk.py @@ -9,7 +9,8 @@ from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.asymmetric import rsa -from acme.jose import b64 +import six + from acme.jose import errors from acme.jose import json_util from acme.jose import util @@ -87,7 +88,7 @@ class JWK(json_util.TypedJSONObjectWithFields): key, cls.cryptography_key_types): raise errors.Error("Unable to deserialize {0} into {1}".format( key.__class__, cls.__class__)) - for jwk_cls in cls.TYPES.itervalues(): + for jwk_cls in six.itervalues(cls.TYPES): if isinstance(key, jwk_cls.cryptography_key_types): return jwk_cls(key=key) raise errors.Error("Unsupported algorithm: {0}".format(key.__class__)) @@ -127,11 +128,11 @@ class JWKOct(JWK): # algorithm intended to be used with the key, unless the # application uses another means or convention to determine # the algorithm used. - return {'k': self.key} + return {'k': json_util.encode_b64jose(self.key)} @classmethod def fields_from_json(cls, jobj): - return cls(key=jobj['k']) + return cls(key=json_util.decode_b64jose(jobj['k'])) def public_key(self): return self @@ -158,18 +159,25 @@ class JWKRSA(JWK): @classmethod def _encode_param(cls, data): + """Encode Base64urlUInt. + + :type data: long + :rtype: unicode + + """ def _leading_zeros(arg): if len(arg) % 2: return '0' + arg return arg - return b64.b64encode(binascii.unhexlify( + return json_util.encode_b64jose(binascii.unhexlify( _leading_zeros(hex(data)[2:].rstrip('L')))) @classmethod def _decode_param(cls, data): + """Decode Base64urlUInt.""" try: - return long(binascii.hexlify(json_util.decode_b64jose(data)), 16) + return int(binascii.hexlify(json_util.decode_b64jose(data)), 16) except ValueError: # invalid literal for long() with base 16 raise errors.DeserializationError() @@ -198,17 +206,20 @@ class JWKRSA(JWK): raise errors.Error( "Some private parameters are missing: {0}".format( all_params)) - p, q, dp, dq, qi = tuple(cls._decode_param(x) for x in all_params) + p, q, dp, dq, qi = tuple( + cls._decode_param(x) for x in all_params) # TODO: check for oth else: - p, q = rsa.rsa_recover_prime_factors(n, e, d) # cryptography>=0.8 + # cryptography>=0.8 + p, q = rsa.rsa_recover_prime_factors(n, e, d) dp = rsa.rsa_crt_dmp1(d, p) dq = rsa.rsa_crt_dmq1(d, q) qi = rsa.rsa_crt_iqmp(p, q) key = rsa.RSAPrivateNumbers( - p, q, d, dp, dq, qi, public_numbers).private_key(default_backend()) + p, q, d, dp, dq, qi, public_numbers).private_key( + default_backend()) return cls(key=key) @@ -234,4 +245,4 @@ class JWKRSA(JWK): 'qi': private.iqmp, } return dict((key, self._encode_param(value)) - for key, value in params.iteritems()) + for key, value in six.iteritems(params)) diff --git a/acme/acme/jose/jwk_test.py b/acme/acme/jose/jwk_test.py index 86674b726..5462af6b0 100644 --- a/acme/acme/jose/jwk_test.py +++ b/acme/acme/jose/jwk_test.py @@ -4,6 +4,7 @@ import unittest from acme import test_util from acme.jose import errors +from acme.jose import json_util from acme.jose import util @@ -29,8 +30,8 @@ class JWKOctTest(unittest.TestCase): def setUp(self): from acme.jose.jwk import JWKOct - self.jwk = JWKOct(key='foo') - self.jobj = {'kty': 'oct', 'k': 'foo'} + self.jwk = JWKOct(key=b'foo') + self.jobj = {'kty': 'oct', 'k': json_util.encode_b64jose(b'foo')} def test_to_partial_json(self): self.assertEqual(self.jwk.to_partial_json(), self.jobj) @@ -45,7 +46,7 @@ class JWKOctTest(unittest.TestCase): def test_load(self): from acme.jose.jwk import JWKOct - self.assertEqual(self.jwk, JWKOct.load('foo')) + self.assertEqual(self.jwk, JWKOct.load(b'foo')) def test_public_key(self): self.assertTrue(self.jwk.public_key() is self.jwk) @@ -64,7 +65,8 @@ class JWKRSATest(unittest.TestCase): 'n': 'm2Fylv-Uz7trgTW8EBHP3FQSMeZs2GNQ6VRo1sIVJEk', } # pylint: disable=protected-access - self.jwk256_not_comparable = JWKRSA(key=RSA256_KEY.public_key()._wrapped) + self.jwk256_not_comparable = JWKRSA( + key=RSA256_KEY.public_key()._wrapped) self.jwk512 = JWKRSA(key=RSA512_KEY.public_key()) self.jwk512json = { 'kty': 'RSA', @@ -91,6 +93,12 @@ class JWKRSATest(unittest.TestCase): self.jwk256_not_comparable.key, util.ComparableRSAKey)) self.assertEqual(self.jwk256, self.jwk256_not_comparable) + def test_encode_param_zero(self): + from acme.jose.jwk import JWKRSA + # pylint: disable=protected-access + # TODO: move encode/decode _param to separate class + self.assertEqual('AA', JWKRSA._encode_param(0)) + def test_equals(self): self.assertEqual(self.jwk256, self.jwk256) self.assertEqual(self.jwk512, self.jwk512) diff --git a/acme/acme/jose/jws.py b/acme/acme/jose/jws.py index 6d1a5db2b..7ecc87bf2 100644 --- a/acme/acme/jose/jws.py +++ b/acme/acme/jose/jws.py @@ -4,6 +4,7 @@ import base64 import sys import OpenSSL +import six from acme.jose import b64 from acme.jose import errors @@ -80,7 +81,7 @@ class Header(json_util.JSONObjectWithFields): def not_omitted(self): """Fields that would not be omitted in the JSON object.""" return dict((name, getattr(self, name)) - for name, field in self._fields.iteritems() + for name, field in six.iteritems(self._fields) if not field.omit(getattr(self, name))) def __add__(self, other): @@ -148,15 +149,22 @@ class Signature(json_util.JSONObjectWithFields): header_cls = Header __slots__ = ('combined',) - protected = json_util.Field( - 'protected', omitempty=True, default='', - decoder=json_util.decode_b64jose, encoder=b64.b64encode) # TODO: utf-8? + protected = json_util.Field('protected', omitempty=True, default='') header = json_util.Field( 'header', omitempty=True, default=header_cls(), decoder=header_cls.from_json) signature = json_util.Field( 'signature', decoder=json_util.decode_b64jose, - encoder=b64.b64encode) + encoder=json_util.encode_b64jose) + + @protected.encoder + def protected(value): # pylint: disable=missing-docstring,no-self-argument + # wrong type guess (Signature, not bytes) | pylint: disable=no-member + return json_util.encode_b64jose(value.encode('utf-8')) + + @protected.decoder + def protected(value): # pylint: disable=missing-docstring,no-self-argument + return json_util.decode_b64jose(value).decode('utf-8') def __init__(self, **kwargs): if 'combined' not in kwargs: @@ -178,6 +186,11 @@ class Signature(json_util.JSONObjectWithFields): kwargs['combined'] = combined return kwargs + @classmethod + def _msg(cls, protected, payload): + return (b64.b64encode(protected.encode('utf-8')) + b'.' + + b64.b64encode(payload)) + def verify(self, payload, key=None): """Verify. @@ -188,8 +201,7 @@ class Signature(json_util.JSONObjectWithFields): key = self.combined.find_key() if key is None else key return self.combined.alg.verify( key=key.key, sig=self.signature, - msg=(b64.b64encode(self.protected) + '.' + - b64.b64encode(payload))) + msg=self._msg(self.protected, payload)) @classmethod def sign(cls, payload, key, alg, include_jwk=True, @@ -220,8 +232,7 @@ class Signature(json_util.JSONObjectWithFields): protected = '' header = cls.header_cls(**header_params) # pylint: disable=star-args - signature = alg.sign(key.key, b64.b64encode(protected) - + '.' + b64.b64encode(payload)) + signature = alg.sign(key.key, cls._msg(protected, payload)) return cls(protected=protected, header=header, signature=signature) @@ -244,7 +255,7 @@ class JWS(json_util.JSONObjectWithFields): """JSON Web Signature. :ivar str payload: JWS Payload. - :ivar str signaturea: JWS Signatures. + :ivar str signature: JWS Signatures. """ __slots__ = ('payload', 'signatures') @@ -272,33 +283,45 @@ class JWS(json_util.JSONObjectWithFields): return self.signatures[0] def to_compact(self): - """Compact serialization.""" + """Compact serialization. + + :rtype: bytes + + """ assert len(self.signatures) == 1 assert 'alg' not in self.signature.header.not_omitted() # ... it must be in protected - return '{0}.{1}.{2}'.format( - b64.b64encode(self.signature.protected), - b64.b64encode(self.payload), + return ( + b64.b64encode(self.signature.protected.encode('utf-8')) + + b'.' + + b64.b64encode(self.payload) + + b'.' + b64.b64encode(self.signature.signature)) @classmethod def from_compact(cls, compact): - """Compact deserialization.""" + """Compact deserialization. + + :param bytes compact: + + """ try: - protected, payload, signature = compact.split('.') + protected, payload, signature = compact.split(b'.') except ValueError: raise errors.DeserializationError( 'Compact JWS serialization should comprise of exactly' ' 3 dot-separated components') - sig = cls.signature_cls(protected=json_util.decode_b64jose(protected), - signature=json_util.decode_b64jose(signature)) - return cls(payload=json_util.decode_b64jose(payload), signatures=(sig,)) + + sig = cls.signature_cls( + protected=b64.b64decode(protected).decode('utf-8'), + signature=b64.b64decode(signature)) + return cls(payload=b64.b64decode(payload), signatures=(sig,)) def to_partial_json(self, flat=True): # pylint: disable=arguments-differ assert self.signatures - payload = b64.b64encode(self.payload) + payload = json_util.encode_b64jose(self.payload) if flat and len(self.signatures) == 1: ret = self.signatures[0].to_partial_json() @@ -329,34 +352,36 @@ class CLI(object): def sign(cls, args): """Sign.""" key = args.alg.kty.load(args.key.read()) + args.key.close() if args.protect is None: args.protect = [] if args.compact: args.protect.append('alg') - sig = JWS.sign(payload=sys.stdin.read(), key=key, alg=args.alg, + sig = JWS.sign(payload=sys.stdin.read().encode(), key=key, alg=args.alg, protect=set(args.protect)) if args.compact: - print sig.to_compact() + six.print_(sig.to_compact().decode('utf-8')) else: # JSON - print sig.json_dumps_pretty() + six.print_(sig.json_dumps_pretty()) @classmethod def verify(cls, args): """Verify.""" if args.compact: - sig = JWS.from_compact(sys.stdin.read()) + sig = JWS.from_compact(sys.stdin.read().encode()) else: # JSON try: sig = JWS.json_loads(sys.stdin.read()) except errors.Error as error: - print error + six.print_(error) return -1 if args.key is not None: assert args.kty is not None key = args.kty.load(args.key.read()).public_key() + args.key.close() else: key = None @@ -387,7 +412,7 @@ class CLI(object): parser_sign = subparsers.add_parser('sign') parser_sign.set_defaults(func=cls.sign) parser_sign.add_argument( - '-k', '--key', type=argparse.FileType(), required=True) + '-k', '--key', type=argparse.FileType('rb'), required=True) parser_sign.add_argument( '-a', '--alg', type=cls._alg_type, default=jwa.RS256) parser_sign.add_argument( @@ -396,7 +421,7 @@ class CLI(object): parser_verify = subparsers.add_parser('verify') parser_verify.set_defaults(func=cls.verify) parser_verify.add_argument( - '-k', '--key', type=argparse.FileType(), required=False) + '-k', '--key', type=argparse.FileType('rb'), required=False) parser_verify.add_argument( '--kty', type=cls._kty_type, required=False) diff --git a/acme/acme/jose/jws_test.py b/acme/acme/jose/jws_test.py index 7a3e8cb83..69341f228 100644 --- a/acme/acme/jose/jws_test.py +++ b/acme/acme/jose/jws_test.py @@ -7,8 +7,8 @@ import OpenSSL from acme import test_util -from acme.jose import b64 from acme.jose import errors +from acme.jose import json_util from acme.jose import jwa from acme.jose import jwk @@ -73,7 +73,7 @@ class HeaderTest(unittest.TestCase): self.assertEqual(jobj, {'x5c': [cert_b64, cert_b64]}) self.assertEqual(header, Header.from_json(jobj)) jobj['x5c'][0] = base64.b64encode( - 'xxx' + OpenSSL.crypto.dump_certificate( + b'xxx' + OpenSSL.crypto.dump_certificate( OpenSSL.crypto.FILETYPE_ASN1, CERT)) self.assertRaises(errors.DeserializationError, Header.from_json, jobj) @@ -90,7 +90,7 @@ class SignatureTest(unittest.TestCase): from acme.jose.jws import Header from acme.jose.jws import Signature self.assertEqual( - Signature(signature='foo', header=Header(alg=jwa.RS256)), + Signature(signature=b'foo', header=Header(alg=jwa.RS256)), Signature.from_json( {'signature': 'Zm9v', 'header': {'alg': 'RS256'}})) @@ -109,12 +109,12 @@ class JWSTest(unittest.TestCase): from acme.jose.jws import JWS self.unprotected = JWS.sign( - payload='foo', key=self.privkey, alg=jwa.RS256) + payload=b'foo', key=self.privkey, alg=jwa.RS256) self.protected = JWS.sign( - payload='foo', key=self.privkey, alg=jwa.RS256, + payload=b'foo', key=self.privkey, alg=jwa.RS256, protect=frozenset(['jwk', 'alg'])) self.mixed = JWS.sign( - payload='foo', key=self.privkey, alg=jwa.RS256, + payload=b'foo', key=self.privkey, alg=jwa.RS256, protect=frozenset(['alg'])) def test_pubkey_jwk(self): @@ -134,8 +134,8 @@ class JWSTest(unittest.TestCase): def test_compact_lost_unprotected(self): compact = self.mixed.to_compact() self.assertEqual( - 'eyJhbGciOiAiUlMyNTYifQ.Zm9v.OHdxFVj73l5LpxbFp1AmYX4yJM0Pyb' - '_893n1zQjpim_eLS5J1F61lkvrCrCDErTEJnBGOGesJ72M7b6Ve1cAJA', + b'eyJhbGciOiAiUlMyNTYifQ.Zm9v.OHdxFVj73l5LpxbFp1AmYX4yJM0Pyb' + b'_893n1zQjpim_eLS5J1F61lkvrCrCDErTEJnBGOGesJ72M7b6Ve1cAJA', compact) from acme.jose.jws import JWS @@ -147,7 +147,7 @@ class JWSTest(unittest.TestCase): def test_from_compact_missing_components(self): from acme.jose.jws import JWS - self.assertRaises(errors.DeserializationError, JWS.from_compact, '.') + self.assertRaises(errors.DeserializationError, JWS.from_compact, b'.') def test_json_omitempty(self): protected_jobj = self.protected.to_partial_json(flat=True) @@ -164,10 +164,12 @@ class JWSTest(unittest.TestCase): def test_json_flat(self): jobj_to = { - 'signature': b64.b64encode(self.mixed.signature.signature), - 'payload': b64.b64encode('foo'), + 'signature': json_util.encode_b64jose( + self.mixed.signature.signature), + 'payload': json_util.encode_b64jose(b'foo'), 'header': self.mixed.signature.header, - 'protected': b64.b64encode(self.mixed.signature.protected), + 'protected': json_util.encode_b64jose( + self.mixed.signature.protected.encode('utf-8')), } jobj_from = jobj_to.copy() jobj_from['header'] = jobj_from['header'].to_json() @@ -179,7 +181,7 @@ class JWSTest(unittest.TestCase): def test_json_not_flat(self): jobj_to = { 'signatures': (self.mixed.signature,), - 'payload': b64.b64encode('foo'), + 'payload': json_util.encode_b64jose(b'foo'), } jobj_from = jobj_to.copy() jobj_from['signatures'] = [jobj_to['signatures'][0].to_json()] diff --git a/acme/acme/jose/util.py b/acme/acme/jose/util.py index eebbe7468..fd58a9e97 100644 --- a/acme/acme/jose/util.py +++ b/acme/acme/jose/util.py @@ -3,6 +3,7 @@ import collections from cryptography.hazmat.primitives.asymmetric import rsa import OpenSSL +import six class abstractclassmethod(classmethod): @@ -156,7 +157,8 @@ class ImmutableMap(collections.Mapping, collections.Hashable): def __repr__(self): return '{0}({1})'.format(self.__class__.__name__, ', '.join( - '{0}={1!r}'.format(key, value) for key, value in self.iteritems())) + '{0}={1!r}'.format(key, value) + for key, value in six.iteritems(self))) class frozendict(collections.Mapping, collections.Hashable): @@ -174,7 +176,7 @@ class frozendict(collections.Mapping, collections.Hashable): # TODO: support generators/iterators object.__setattr__(self, '_items', items) - object.__setattr__(self, '_keys', tuple(sorted(items.iterkeys()))) + object.__setattr__(self, '_keys', tuple(sorted(six.iterkeys(items)))) def __getitem__(self, key): return self._items[key] @@ -185,8 +187,11 @@ class frozendict(collections.Mapping, collections.Hashable): def __len__(self): return len(self._items) + def _sorted_items(self): + return tuple((key, self[key]) for key in self._keys) + def __hash__(self): - return hash(tuple((key, value) for key, value in self.items())) + return hash(self._sorted_items()) def __getattr__(self, name): try: @@ -198,5 +203,5 @@ class frozendict(collections.Mapping, collections.Hashable): raise AttributeError("can't set attribute") def __repr__(self): - return 'frozendict({0})'.format(', '.join( - '{0}={1!r}'.format(key, value) for key, value in self.iteritems())) + return 'frozendict({0})'.format(', '.join('{0}={1!r}'.format( + key, value) for key, value in self._sorted_items())) diff --git a/acme/acme/jose/util_test.py b/acme/acme/jose/util_test.py index 1bde9ebd9..4cdd9127f 100644 --- a/acme/acme/jose/util_test.py +++ b/acme/acme/jose/util_test.py @@ -2,6 +2,8 @@ import functools import unittest +import six + from acme import test_util @@ -168,13 +170,13 @@ class frozendictTest(unittest.TestCase): # pylint: disable=invalid-name def test_init_other_raises_type_error(self): from acme.jose.util import frozendict # specifically fail for generators... - self.assertRaises(TypeError, frozendict, {'a': 'b'}.iteritems()) + self.assertRaises(TypeError, frozendict, six.iteritems({'a': 'b'})) def test_len(self): self.assertEqual(2, len(self.fdict)) def test_hash(self): - self.assertEqual(1278944519403861804, hash(self.fdict)) + self.assertTrue(isinstance(hash(self.fdict), int)) def test_getattr_proxy(self): self.assertEqual(1, self.fdict.x) diff --git a/acme/setup.py b/acme/setup.py index d83131d2a..481a35eb6 100644 --- a/acme/setup.py +++ b/acme/setup.py @@ -14,6 +14,7 @@ install_requires = [ 'PyOpenSSL', 'pytz', 'requests', + 'six', 'werkzeug', ]