diff --git a/letsencrypt/crypto_util.py b/letsencrypt/crypto_util.py index 79cd24ed6..777b4d006 100644 --- a/letsencrypt/crypto_util.py +++ b/letsencrypt/crypto_util.py @@ -201,29 +201,26 @@ def valid_privkey(privkey): return False -def _pyopenssl_load(data, method, types=( - OpenSSL.crypto.FILETYPE_PEM, OpenSSL.crypto.FILETYPE_ASN1)): - openssl_errors = [] - for filetype in types: - try: - return method(filetype, data), filetype - except OpenSSL.crypto.Error as error: # TODO: anything else? - openssl_errors.append(error) - raise errors.Error("Unable to load: {0}".format(",".join( - str(error) for error in openssl_errors))) - - def pyopenssl_load_certificate(data): """Load PEM/DER certificate. :raises errors.Error: """ - return _pyopenssl_load(data, OpenSSL.crypto.load_certificate) + + openssl_errors = [] + + for file_type in (OpenSSL.crypto.FILETYPE_PEM, OpenSSL.crypto.FILETYPE_ASN1): + try: + return OpenSSL.crypto.load_certificate(file_type, data), file_type + except OpenSSL.crypto.Error as error: # TODO: other errors? + openssl_errors.append(error) + raise errors.Error("Unable to load: {0}".format(",".join( + str(error) for error in openssl_errors))) -def _get_sans_from_cert_or_req( - cert_or_req_str, load_func, typ=OpenSSL.crypto.FILETYPE_PEM): +def _get_sans_from_cert_or_req(cert_or_req_str, load_func, + typ=OpenSSL.crypto.FILETYPE_PEM): try: cert_or_req = load_func(typ, cert_or_req_str) except OpenSSL.crypto.Error as error: diff --git a/letsencrypt/tests/crypto_util_test.py b/letsencrypt/tests/crypto_util_test.py index b4d2aa394..2e04c748a 100644 --- a/letsencrypt/tests/crypto_util_test.py +++ b/letsencrypt/tests/crypto_util_test.py @@ -8,6 +8,7 @@ import OpenSSL import mock import zope.component +from letsencrypt import errors from letsencrypt import interfaces from letsencrypt.tests import test_util @@ -213,5 +214,23 @@ class GetSANsFromCSRTest(unittest.TestCase): [], self._call(test_util.load_vector('csr-nosans.pem'))) +class CertLoaderTest(unittest.TestCase): + """Tests for letsencrypt.crypto_util.pyopenssl_load_certificate""" + + def test_load_valid_cert(self): + from letsencrypt.crypto_util import pyopenssl_load_certificate + + cert, file_type = pyopenssl_load_certificate(CERT) + self.assertEqual(cert.digest('sha1'), + OpenSSL.crypto.load_certificate(file_type, CERT).digest('sha1')) + + def test_load_invalid_cert(self): + from letsencrypt.crypto_util import pyopenssl_load_certificate + bad_cert_data = CERT.replace("BEGIN CERTIFICATE", "ASDFASDFASDF!!!") + + with self.assertRaises(errors.Error): + pyopenssl_load_certificate(bad_cert_data) + + if __name__ == '__main__': unittest.main() # pragma: no cover