From dad799d4284ea68ad18927823723db551749ddc9 Mon Sep 17 00:00:00 2001 From: Jakub Warmuz Date: Wed, 11 Feb 2015 16:08:55 +0000 Subject: [PATCH] acme.messages.Message.get_msg_cls --- letsencrypt/acme/messages.py | 38 ++++++++++++++++++++----------- letsencrypt/acme/messages_test.py | 4 ++-- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/letsencrypt/acme/messages.py b/letsencrypt/acme/messages.py index c91c95f59..de14dac96 100644 --- a/letsencrypt/acme/messages.py +++ b/letsencrypt/acme/messages.py @@ -54,6 +54,30 @@ class Message(util.JSONDeSerializable, util.ImmutableMap): """ raise NotImplementedError() + @classmethod + def get_msg_cls(cls, jobj): + """Get the registered class for ``jobj``.""" + if cls in cls.TYPES.itervalues(): + # cls is already registered Message type, force to use it + # so that, e.g Revocation.from_json(jobj) fails if + # jobj["type"] != "revocation". + return cls + + if not isinstance(jobj, dict): + raise errors.ValidationError( + "{0} is not a dictionary object".format(jobj)) + try: + msg_type = jobj["type"] + except KeyError: + raise errors.ValidationError("missing type field") + + try: + msg_cls = cls.TYPES[msg_type] + except KeyError: + raise errors.UnrecognizedMessageTypeError(msg_type) + + return msg_cls + @classmethod def from_json(cls, jobj, validate=True): """Deserialize validated ACME message from JSON string. @@ -69,19 +93,7 @@ class Message(util.JSONDeSerializable, util.ImmutableMap): :rtype: subclass of :class:`Message` """ - if not isinstance(jobj, dict): - raise errors.ValidationError( - "{0} is not a dictionary object".format(jobj)) - try: - msg_type = jobj["type"] - except KeyError: - raise errors.ValidationError("missing type field") - - try: - msg_cls = cls.TYPES[msg_type] - except KeyError: - raise errors.UnrecognizedMessageTypeError(msg_type) - + msg_cls = cls.get_msg_cls(jobj) if validate: msg_cls.validate_json(jobj) # pylint: disable=protected-access diff --git a/letsencrypt/acme/messages_test.py b/letsencrypt/acme/messages_test.py index 0820c8e73..b1c2f9a3c 100644 --- a/letsencrypt/acme/messages_test.py +++ b/letsencrypt/acme/messages_test.py @@ -413,8 +413,8 @@ class RevocationTest(unittest.TestCase): self.assertEqual(self.msg.to_json(), self.jmsg) def test_from_json(self): - from letsencrypt.acme.messages import Error - self.assertEqual(Error.from_json(self.jmsg), self.msg) + from letsencrypt.acme.messages import Revocation + self.assertEqual(Revocation.from_json(self.jmsg), self.msg) class RevocationRequestTest(unittest.TestCase):