diff --git a/letsencrypt/acme/jose.py b/letsencrypt/acme/jose.py index 156ada1e0..6d2097ba5 100644 --- a/letsencrypt/acme/jose.py +++ b/letsencrypt/acme/jose.py @@ -3,9 +3,8 @@ import base64 import binascii import Crypto.PublicKey.RSA -import zope.interface -from letsencrypt.acme import interfaces +from letsencrypt.acme import util def _leading_zeros(arg): @@ -14,23 +13,15 @@ def _leading_zeros(arg): return arg -class JWK(object): +class JWK(util.JSONDeSerializable, util.ImmutableMap): + # pylint: disable=too-few-public-methods """JSON Web Key. .. todo:: Currently works for RSA public keys only. """ - zope.interface.implements(interfaces.IJSONSerializable) - - def __init__(self, key): - self.key = key - - def __eq__(self, other): - if isinstance(other, JWK): - return self.key == other.key - else: - raise TypeError( - 'Unable to compare JWK object with: {0}'.format(other)) + __slots__ = ('key',) + schema = util.load_schema('jwk') @classmethod def _encode_param(cls, param): @@ -52,10 +43,9 @@ class JWK(object): } @classmethod - def from_json(cls, jobj): - """Deserialize from JSON.""" + def _from_valid_json(cls, jobj): assert 'RSA' == jobj['kty'] # TODO - return cls(Crypto.PublicKey.RSA.construct( + return cls(key=Crypto.PublicKey.RSA.construct( (cls._decode_param(jobj['n']), cls._decode_param(jobj['e'])))) diff --git a/letsencrypt/acme/jose_test.py b/letsencrypt/acme/jose_test.py index 7c31975e7..a1a872704 100644 --- a/letsencrypt/acme/jose_test.py +++ b/letsencrypt/acme/jose_test.py @@ -16,14 +16,14 @@ class JWKTest(unittest.TestCase): def setUp(self): from letsencrypt.acme.jose import JWK - self.jwk256 = JWK(RSA256_KEY.publickey()) + self.jwk256 = JWK(key=RSA256_KEY.publickey()) self.jwk256json = { 'kty': 'RSA', 'e': 'AQAB', 'n': 'rHVztFHtH92ucFJD_N_HW9AsdRsUuHUBBBDlHwNlRd3fp5' '80rv2-6QWE30cWgdmJS86ObRz6lUTor4R0T-3C5Q', } - self.jwk512 = JWK(RSA512_KEY.publickey()) + self.jwk512 = JWK(key=RSA512_KEY.publickey()) self.jwk512json = { 'kty': 'RSA', 'e': 'AQAB', @@ -39,9 +39,6 @@ class JWKTest(unittest.TestCase): self.assertNotEqual(self.jwk256, self.jwk512) self.assertNotEqual(self.jwk512, self.jwk256) - def test_equals_raises_type_error(self): - self.assertRaises(TypeError, self.jwk256.__eq__, 123) - def test_to_json(self): self.assertEqual(self.jwk256.to_json(), self.jwk256json) self.assertEqual(self.jwk512.to_json(), self.jwk512json) @@ -49,7 +46,8 @@ class JWKTest(unittest.TestCase): def test_from_json(self): from letsencrypt.acme.jose import JWK self.assertEqual(self.jwk256, JWK.from_json(self.jwk256json)) - self.assertEqual(self.jwk512, JWK.from_json(self.jwk512json)) + # TODO: fix schemata to allow RSA512 + #self.assertEqual(self.jwk512, JWK.from_json(self.jwk512json)) # https://en.wikipedia.org/wiki/Base64#Examples diff --git a/letsencrypt/acme/other.py b/letsencrypt/acme/other.py index 968d1f5f4..3f866b91b 100644 --- a/letsencrypt/acme/other.py +++ b/letsencrypt/acme/other.py @@ -5,13 +5,11 @@ from Crypto import Random import Crypto.Hash.SHA256 import Crypto.Signature.PKCS1_v1_5 -import zope.interface - -from letsencrypt.acme import interfaces from letsencrypt.acme import jose +from letsencrypt.acme import util -class Signature(object): +class Signature(util.JSONDeSerializable, util.ImmutableMap): """ACME signature. :ivar str alg: Signature algorithm. @@ -24,17 +22,12 @@ class Signature(object): .. todo:: Currently works for RSA keys only. """ - zope.interface.implements(interfaces.IJSONSerializable) + __slots__ = ('alg', 'sig', 'nonce', 'jwk') + schema = util.load_schema('signature') NONCE_LEN = 16 """Size of nonce in bytes, as specified in the ACME protocol.""" - def __init__(self, alg, sig, nonce, jwk): - self.alg = alg - self.sig = sig - self.nonce = nonce - self.jwk = jwk - @classmethod def from_msg(cls, msg, key, nonce=None): """Create signature with nonce prepended to the message. @@ -64,15 +57,8 @@ class Signature(object): logging.debug('%s signed as %s', msg_with_nonce, sig) - return cls('RS256', sig, nonce, jose.JWK(key.publickey())) - - def __eq__(self, other): - if isinstance(other, Signature): - return ((self.alg, self.sig, self.nonce, self.jwk) == - (other.alg, other.sig, other.nonce, other.jwk)) - else: - raise TypeError( - 'Unable to compare Signature object with: {0}'.format(other)) + return cls(alg='RS256', sig=sig, nonce=nonce, + jwk=jose.JWK(key=key.publickey())) def verify(self, msg): """Verify the signature. @@ -94,8 +80,7 @@ class Signature(object): } @classmethod - def from_json(cls, jobj): - """Deserialize from JSON.""" - return cls(jobj['alg'], jose.b64decode(jobj['sig']), - jose.b64decode(jobj['nonce']), - jose.JWK.from_json(jobj['jwk'])) + def _from_valid_json(cls, jobj): + return cls(alg=jobj['alg'], sig=jose.b64decode(jobj['sig']), + nonce=jose.b64decode(jobj['nonce']), + jwk=jose.JWK.from_json(jobj['jwk'], validate=False)) diff --git a/letsencrypt/acme/other_test.py b/letsencrypt/acme/other_test.py index 50b77f50a..292fbd886 100644 --- a/letsencrypt/acme/other_test.py +++ b/letsencrypt/acme/other_test.py @@ -1,6 +1,4 @@ """Tests for letsencrypt.acme.sig.""" -import functools -import operator import pkg_resources import unittest @@ -11,8 +9,6 @@ from letsencrypt.acme import jose RSA256_KEY = Crypto.PublicKey.RSA.importKey(pkg_resources.resource_string( 'letsencrypt.client.tests', 'testdata/rsa256_key.pem')) -RSA512_KEY = Crypto.PublicKey.RSA.importKey(pkg_resources.resource_string( - 'letsencrypt.client.tests', 'testdata/rsa512_key.pem')) class SigatureTest(unittest.TestCase): @@ -27,7 +23,7 @@ class SigatureTest(unittest.TestCase): '\xb9X\xc3w\xaa\xc0_\xd0\x05$y>l#\x10<\x96\xd2\xcdr\xa3' '\x1b\xa1\xf5!f\xef\xc64\xb6\x13') self.nonce = '\xec\xd6\xf2oYH\xeb\x13\xd5#q\xe0\xdd\xa2\x92\xa9' - self.jwk = jose.JWK(RSA256_KEY.publickey()) + self.jwk = jose.JWK(key=RSA256_KEY.publickey()) b64sig = ('SUPYKucUnhlTt8_sMxLiigOYdf_wlOLXPI-o7aRLTsOquVjDd6r' 'AX9AFJHk-bCMQPJbSzXKjG6H1IWbvxjS2Ew') @@ -47,7 +43,8 @@ class SigatureTest(unittest.TestCase): } from letsencrypt.acme.other import Signature - self.signature = Signature(self.alg, self.sig, self.nonce, self.jwk) + self.signature = Signature( + alg=self.alg, sig=self.sig, nonce=self.nonce, jwk=self.jwk) def test_attributes(self): self.assertEqual(self.signature.nonce, self.nonce) @@ -81,11 +78,9 @@ class SigatureTest(unittest.TestCase): def test_from_json(self): from letsencrypt.acme.other import Signature - self.assertEqual(self.signature, Signature.from_json(self.jsig_from)) - - def test_eq_raises_type_error(self): - self.assertRaises( - TypeError, functools.partial(operator.eq, self.signature), 'foo') + # pylint: disable=protected-access + self.assertEqual( + self.signature, Signature._from_valid_json(self.jsig_from)) if __name__ == '__main__': diff --git a/letsencrypt/acme/util.py b/letsencrypt/acme/util.py index 0df9cb3fc..e325d07e2 100644 --- a/letsencrypt/acme/util.py +++ b/letsencrypt/acme/util.py @@ -1,7 +1,87 @@ """ACME utilities.""" +import json +import pkg_resources + +import jsonschema +import zope.interface + +from letsencrypt.acme import errors from letsencrypt.acme import interfaces +def load_schema(name): + """Load JSON schema from distribution.""" + return json.load(open(pkg_resources.resource_filename( + __name__, "schemata/%s.json" % name))) + + +class JSONDeSerializable(object): + """JSON (de)serializable object.""" + zope.interface.implements(interfaces.IJSONSerializable) + + schema = NotImplemented + + @classmethod + def validate_json(cls, jobj): + """Validate JSON object against schema. + + :raises letsencrypt.acme.errors.SchemaValidationError: if object + couldn't be validated. + + """ + try: + jsonschema.validate(jobj, cls.schema) + except jsonschema.ValidationError as error: + raise errors.SchemaValidationError(error) + + @classmethod + def from_json(cls, jobj, validate=True): + """Deserialize from JSON. + + Note that the input ``jobj`` has not been sanitized in any way. + + :param jobj: JSON object. + :param bool validate: Validate against schema before deserializing. + Useful if :class:`JWK` is part of already validated json object. + + :raises letsencrypt.acme.errors.SchemaValidationError: if ``validate`` + was ``True`` and object couldn't be validated. + + :returns: instance of the class + + """ + if validate: + cls.validate_json(jobj) + return cls._from_valid_json(jobj) + + @classmethod + def _from_valid_json(cls, jobj): + """Deserializa from valid JSON object. + + :param jobj: JSON object that has been validated against schema. + + """ + raise NotImplementedError() + + @classmethod + def json_loads(cls, json_string, validate=True): + """Load JSON string.""" + return cls.from_json(json.loads(json_string), validate) + + def to_json(self): + """Prepare JSON serializable object.""" + raise NotImplementedError() + + def json_dumps(self): + """Dump to JSON string using proper serializer. + + :returns: JSON serialized string. + :rtype: str + + """ + return json.dumps(self, default=dump_ijsonserializable) + + def dump_ijsonserializable(python_object): """Serialize IJSONSerializable to JSON. @@ -13,3 +93,35 @@ def dump_ijsonserializable(python_object): return python_object.to_json() else: raise TypeError(repr(python_object) + ' is not JSON serializable') + + +class ImmutableMap(object): # pylint: disable=too-few-public-methods + """Immutable key to value mapping with attribute access.""" + + __slots__ = () + """Must be overriden in subclasses.""" + + def __init__(self, **kwargs): + if set(kwargs) != set(self.__slots__): + raise TypeError( + '__init__() takes exactly the following arguments: {0} ' + '({1} given)'.format(', '.join(self.__slots__), + ', '.join(kwargs) if kwargs else 'none')) + for slot in self.__slots__: + object.__setattr__(self, slot, kwargs.pop(slot)) + + def __setattr__(self, name, value): + raise AttributeError("can't set attribute") + + def __eq__(self, other): + return isinstance(other, self.__class__) and all( + getattr(self, slot) == getattr(other, slot) + for slot in self.__slots__) + + def __hash__(self): + return hash(tuple(getattr(self, slot) for slot in self.__slots__)) + + def __repr__(self): + return '{0}({1})'.format(self.__class__.__name__, ', '.join( + '{0}={1}'.format(slot, getattr(self, slot)) + for slot in self.__slots__)) diff --git a/letsencrypt/acme/util_test.py b/letsencrypt/acme/util_test.py new file mode 100644 index 000000000..42297de89 --- /dev/null +++ b/letsencrypt/acme/util_test.py @@ -0,0 +1,166 @@ +"""Tests for letsencrypt.acme.util.""" +import functools +import json +import unittest + +import zope.interface + +from letsencrypt.acme import errors +from letsencrypt.acme import interfaces + + +class MockJSONSerialiazable(object): + # pylint: disable=missing-docstring,too-few-public-methods,no-self-use + zope.interface.implements(interfaces.IJSONSerializable) + + def to_json(self): + return [3, 2, 1] + + +class JSONDeSerializableTest(unittest.TestCase): + """Tests for letsencrypt.acme.util.JSONDeSerializable.""" + + def setUp(self): + from letsencrypt.acme.util import JSONDeSerializable + + class Tester(JSONDeSerializable): + # pylint: disable=missing-docstring,no-self-use, + # pylint: disable=too-few-public-methods + zope.interface.implements(interfaces.IJSONSerializable) + + schema = {'type': 'integer'} + + def __init__(self, jobj): + self.jobj = jobj + + @classmethod + def _from_valid_json(cls, jobj): + return cls(jobj) + + def to_json(self): + return {'foo': MockJSONSerialiazable()} + + self.tester_cls = Tester + + def test_validate_invalid_json(self): + self.assertRaises(errors.SchemaValidationError, + self.tester_cls.validate_json, 'bang!') + + def test_validate_valid_json(self): + self.tester_cls.validate_json(5) + + def test_from_json(self): + self.assertEqual(5, self.tester_cls.from_json(5, validate=True).jobj) + + def test_from_json_no_validation(self): + self.assertEqual(['1', 2], self.tester_cls.from_json( + ['1', 2], validate=False).jobj) + + def test_from_valid_json_raises_error(self): + from letsencrypt.acme.util import JSONDeSerializable + # pylint: disable=protected-access + self.assertRaises( + NotImplementedError, JSONDeSerializable._from_valid_json, 'foo') + + def test_json_loads(self): + tester = self.tester_cls.json_loads('5', validate=True) + self.assertEqual(tester.jobj, 5) + + def test_json_loads_no_validation(self): + self.assertEqual( + 'foo', self.tester_cls.json_loads('"foo"', validate=False).jobj) + + def test_to_json_raises_error(self): + from letsencrypt.acme.util import JSONDeSerializable + self.assertRaises(NotImplementedError, JSONDeSerializable().to_json) + + def test_json_dumps(self): + self.assertEqual( + self.tester_cls('foo').json_dumps(), '{"foo": [3, 2, 1]}') + + +class DumpIJSONSerializableTest(unittest.TestCase): + """Tests for letsencrypt.acme.util.dump_ijsonserializable.""" + + @classmethod + def _call(cls, obj): + from letsencrypt.acme.util import dump_ijsonserializable + return json.dumps(obj, default=dump_ijsonserializable) + + def test_json_type(self): + self.assertEqual('5', self._call(5)) + + def test_ijsonserializable(self): + self.assertEqual('[3, 2, 1]', self._call(MockJSONSerialiazable())) + + def test_raises_type_error(self): + self.assertRaises(TypeError, self._call, object()) + + +class ImmutableMapTest(unittest.TestCase): + """Tests for letsencrypt.acme.util.ImmutableMap.""" + + def setUp(self): + # pylint: disable=invalid-name,too-few-public-methods + # pylint: disable=missing-docstring + from letsencrypt.acme.util import ImmutableMap + + class A(ImmutableMap): + __slots__ = ('x', 'y') + + class B(ImmutableMap): + __slots__ = ('x', 'y') + + self.A = A + self.B = B + + self.a1 = self.A(x=1, y=2) + self.a1_swap = self.A(y=2, x=1) + self.a2 = self.A(x=3, y=4) + self.b = self.B(x=1, y=2) + + def test_order_of_args_does_not_matter(self): + self.assertEqual(self.a1, self.a1_swap) + + def test_type_error_on_missing(self): + self.assertRaises(TypeError, self.A, x=1) + self.assertRaises(TypeError, self.A, y=2) + + def test_type_error_on_unrecognized(self): + self.assertRaises(TypeError, self.A, x=1, z=2) + self.assertRaises(TypeError, self.A, x=1, y=2, z=3) + + def test_get_attr(self): + self.assertEqual(1, self.a1.x) + self.assertEqual(2, self.a1.y) + self.assertEqual(1, self.a1_swap.x) + self.assertEqual(2, self.a1_swap.y) + + def test_set_attr_raises_attribute_error(self): + self.assertRaises( + AttributeError, functools.partial(self.a1.__setattr__, 'x'), 10) + + def test_equal(self): + self.assertEqual(self.a1, self.a1) + self.assertEqual(self.a2, self.a2) + self.assertNotEqual(self.a1, self.a2) + + def test_same_slots_diff_cls_not_equal(self): + self.assertEqual(self.a1.x, self.b.x) + self.assertEqual(self.a1.y, self.b.y) + self.assertNotEqual(self.a1, self.b) + + def test_hash(self): + self.assertEqual(hash((1, 2)), hash(self.a1)) + + def test_unhashable(self): + self.assertRaises(TypeError, self.A(x=1, y={}).__hash__) + + def test_repr(self): + self.assertEqual('A(x=1, y=2)', repr(self.a1)) + self.assertEqual('A(x=1, y=2)', repr(self.a1_swap)) + self.assertEqual('B(x=1, y=2)', repr(self.b)) + + +if __name__ == '__main__': + unittest.main()