diff --git a/certbot/crypto_util.py b/certbot/crypto_util.py index 6b1b8426c..1e831dd8f 100644 --- a/certbot/crypto_util.py +++ b/certbot/crypto_util.py @@ -296,6 +296,32 @@ def get_sans_from_csr(csr, typ=OpenSSL.crypto.FILETYPE_PEM): csr, OpenSSL.crypto.load_certificate_request, typ) +def _get_names_from_cert_or_req(cert_or_req, load_func, typ): + loaded_cert_or_req = _load_cert_or_req(cert_or_req, load_func, typ) + common_name = loaded_cert_or_req.get_subject().CN + # pylint: disable=protected-access + sans = acme_crypto_util._pyopenssl_cert_or_req_san(loaded_cert_or_req) + + if common_name is None: + return sans + else: + return [common_name] + [d for d in sans if d != common_name] + + +def get_names_from_cert(csr, typ=OpenSSL.crypto.FILETYPE_PEM): + """Get a list of domains from a cert, including the CN if it is set. + + :param str cert: Certificate (encoded). + :param typ: `OpenSSL.crypto.FILETYPE_PEM` or `OpenSSL.crypto.FILETYPE_ASN1` + + :returns: A list of domain names. + :rtype: list + + """ + return _get_names_from_cert_or_req( + csr, OpenSSL.crypto.load_certificate, typ) + + def get_names_from_csr(csr, typ=OpenSSL.crypto.FILETYPE_PEM): """Get a list of domains from a CSR, including the CN if it is set. @@ -306,13 +332,8 @@ def get_names_from_csr(csr, typ=OpenSSL.crypto.FILETYPE_PEM): :rtype: list """ - loaded_csr = _load_cert_or_req( + return _get_names_from_cert_or_req( csr, OpenSSL.crypto.load_certificate_request, typ) - # Use a set to avoid duplication with CN and Subject Alt Names - domains = set(d for d in (loaded_csr.get_subject().CN,) if d is not None) - # pylint: disable=protected-access - domains.update(acme_crypto_util._pyopenssl_cert_or_req_san(loaded_csr)) - return list(domains) def dump_pyopenssl_chain(chain, filetype=OpenSSL.crypto.FILETYPE_PEM): diff --git a/certbot/storage.py b/certbot/storage.py index b0c8245d3..60886e306 100644 --- a/certbot/storage.py +++ b/certbot/storage.py @@ -616,7 +616,7 @@ class RenewableCert(object): # pylint: disable=too-many-instance-attributes if target is None: raise errors.CertStorageError("could not find cert file") with open(target) as f: - return crypto_util.get_sans_from_cert(f.read()) + return crypto_util.get_names_from_cert(f.read()) def autodeployment_is_enabled(self): """Is automatic deployment enabled for this cert? diff --git a/certbot/tests/crypto_util_test.py b/certbot/tests/crypto_util_test.py index fa88e89e7..5a592bbb1 100644 --- a/certbot/tests/crypto_util_test.py +++ b/certbot/tests/crypto_util_test.py @@ -273,6 +273,32 @@ class GetSANsFromCSRTest(unittest.TestCase): [], self._call(test_util.load_vector('csr-nosans.pem'))) +class GetNamesFromCertTest(unittest.TestCase): + """Tests for certbot.crypto_util.get_names_from_cert.""" + + @classmethod + def _call(cls, *args, **kwargs): + from certbot.crypto_util import get_names_from_cert + return get_names_from_cert(*args, **kwargs) + + def test_single(self): + self.assertEqual( + ['example.com'], + self._call(test_util.load_vector('cert.pem'))) + + def test_san(self): + self.assertEqual( + ['example.com', 'www.example.com'], + self._call(test_util.load_vector('cert-san.pem'))) + + def test_common_name_sans_order(self): + # Tests that the common name comes first + # followed by the SANS in alphabetical order + self.assertEqual( + ['example.com'] + ['{0}.example.com'.format(c) for c in 'abcd'], + self._call(test_util.load_vector('cert-5sans.pem'))) + + class GetNamesFromCSRTest(unittest.TestCase): """Tests for certbot.crypto_util.get_names_from_csr.""" @classmethod diff --git a/certbot/tests/storage_test.py b/certbot/tests/storage_test.py index 0c88d3d55..0d907eca3 100644 --- a/certbot/tests/storage_test.py +++ b/certbot/tests/storage_test.py @@ -84,18 +84,20 @@ class BaseRenewableCertTest(unittest.TestCase): def tearDown(self): shutil.rmtree(self.tempdir) + def _write_out_kind(self, kind, ver, value=None): + link = getattr(self.test_rc, kind) + if os.path.lexists(link): + os.unlink(link) + os.symlink(os.path.join(os.path.pardir, os.path.pardir, "archive", + "example.org", "{0}{1}.pem".format(kind, ver)), + link) + with open(link, "w") as f: + f.write(kind if value is None else value) + def _write_out_ex_kinds(self): for kind in ALL_FOUR: - where = getattr(self.test_rc, kind) - os.symlink(os.path.join("..", "..", "archive", "example.org", - "{0}12.pem".format(kind)), where) - with open(where, "w") as f: - f.write(kind) - os.unlink(where) - os.symlink(os.path.join("..", "..", "archive", "example.org", - "{0}11.pem".format(kind)), where) - with open(where, "w") as f: - f.write(kind) + self._write_out_kind(kind, 12) + self._write_out_kind(kind, 11) class RenewableCertTests(BaseRenewableCertTest): @@ -204,10 +206,7 @@ class RenewableCertTests(BaseRenewableCertTest): def test_current_target(self): # Relative path logic - os.symlink(os.path.join("..", "..", "archive", "example.org", - "cert17.pem"), self.test_rc.cert) - with open(self.test_rc.cert, "w") as f: - f.write("cert") + self._write_out_kind("cert", 17) self.assertTrue(os.path.samefile(self.test_rc.current_target("cert"), os.path.join(self.tempdir, "archive", "example.org", @@ -225,12 +224,8 @@ class RenewableCertTests(BaseRenewableCertTest): def test_current_version(self): for ver in (1, 5, 10, 20): - os.symlink(os.path.join("..", "..", "archive", "example.org", - "cert{0}.pem".format(ver)), - self.test_rc.cert) - with open(self.test_rc.cert, "w") as f: - f.write("cert") - os.unlink(self.test_rc.cert) + self._write_out_kind("cert", ver) + os.unlink(self.test_rc.cert) os.symlink(os.path.join("..", "..", "archive", "example.org", "cert10.pem"), self.test_rc.cert) self.assertEqual(self.test_rc.current_version("cert"), 10) @@ -241,61 +236,30 @@ class RenewableCertTests(BaseRenewableCertTest): def test_latest_and_next_versions(self): for ver in xrange(1, 6): for kind in ALL_FOUR: - where = getattr(self.test_rc, kind) - if os.path.islink(where): - os.unlink(where) - os.symlink(os.path.join("..", "..", "archive", "example.org", - "{0}{1}.pem".format(kind, ver)), where) - with open(where, "w") as f: - f.write(kind) + self._write_out_kind(kind, ver) self.assertEqual(self.test_rc.latest_common_version(), 5) self.assertEqual(self.test_rc.next_free_version(), 6) # Having one kind of file of a later version doesn't change the # result - os.unlink(self.test_rc.privkey) - os.symlink(os.path.join("..", "..", "archive", "example.org", - "privkey7.pem"), self.test_rc.privkey) - with open(self.test_rc.privkey, "w") as f: - f.write("privkey") + self._write_out_kind("privkey", 7) self.assertEqual(self.test_rc.latest_common_version(), 5) # ... although it does change the next free version self.assertEqual(self.test_rc.next_free_version(), 8) # Nor does having three out of four change the result - os.unlink(self.test_rc.cert) - os.symlink(os.path.join("..", "..", "archive", "example.org", - "cert7.pem"), self.test_rc.cert) - with open(self.test_rc.cert, "w") as f: - f.write("cert") - os.unlink(self.test_rc.fullchain) - os.symlink(os.path.join("..", "..", "archive", "example.org", - "fullchain7.pem"), self.test_rc.fullchain) - with open(self.test_rc.fullchain, "w") as f: - f.write("fullchain") + self._write_out_kind("cert", 7) + self._write_out_kind("fullchain", 7) self.assertEqual(self.test_rc.latest_common_version(), 5) # If we have everything from a much later version, it does change # the result - ver = 17 for kind in ALL_FOUR: - where = getattr(self.test_rc, kind) - if os.path.islink(where): - os.unlink(where) - os.symlink(os.path.join("..", "..", "archive", "example.org", - "{0}{1}.pem".format(kind, ver)), where) - with open(where, "w") as f: - f.write(kind) + self._write_out_kind(kind, 17) self.assertEqual(self.test_rc.latest_common_version(), 17) self.assertEqual(self.test_rc.next_free_version(), 18) def test_update_link_to(self): for ver in xrange(1, 6): for kind in ALL_FOUR: - where = getattr(self.test_rc, kind) - if os.path.islink(where): - os.unlink(where) - os.symlink(os.path.join("..", "..", "archive", "example.org", - "{0}{1}.pem".format(kind, ver)), where) - with open(where, "w") as f: - f.write(kind) + self._write_out_kind(kind, ver) self.assertEqual(ver, self.test_rc.current_version(kind)) # pylint: disable=protected-access self.test_rc._update_link_to("cert", 3) @@ -312,10 +276,7 @@ class RenewableCertTests(BaseRenewableCertTest): "chain3000.pem") def test_version(self): - os.symlink(os.path.join("..", "..", "archive", "example.org", - "cert12.pem"), self.test_rc.cert) - with open(self.test_rc.cert, "w") as f: - f.write("cert") + self._write_out_kind("cert", 12) # TODO: We should probably test that the directory is still the # same, but it's tricky because we can get an absolute # path out when we put a relative path in. @@ -325,13 +286,7 @@ class RenewableCertTests(BaseRenewableCertTest): def test_update_all_links_to_success(self): for ver in xrange(1, 6): for kind in ALL_FOUR: - where = getattr(self.test_rc, kind) - if os.path.islink(where): - os.unlink(where) - os.symlink(os.path.join("..", "..", "archive", "example.org", - "{0}{1}.pem".format(kind, ver)), where) - with open(where, "w") as f: - f.write(kind) + self._write_out_kind(kind, ver) self.assertEqual(ver, self.test_rc.current_version(kind)) self.assertEqual(self.test_rc.latest_common_version(), 5) for ver in xrange(1, 6): @@ -376,13 +331,7 @@ class RenewableCertTests(BaseRenewableCertTest): def test_has_pending_deployment(self): for ver in xrange(1, 6): for kind in ALL_FOUR: - where = getattr(self.test_rc, kind) - if os.path.islink(where): - os.unlink(where) - os.symlink(os.path.join("..", "..", "archive", "example.org", - "{0}{1}.pem".format(kind, ver)), where) - with open(where, "w") as f: - f.write(kind) + self._write_out_kind(kind, ver) self.assertEqual(ver, self.test_rc.current_version(kind)) for ver in xrange(1, 6): self.test_rc.update_all_links_to(ver) @@ -395,24 +344,22 @@ class RenewableCertTests(BaseRenewableCertTest): def test_names(self): # Trying the current version - test_cert = test_util.load_vector("cert-san.pem") - os.symlink(os.path.join("..", "..", "archive", "example.org", - "cert12.pem"), self.test_rc.cert) - with open(self.test_rc.cert, "w") as f: - f.write(test_cert) + self._write_out_kind("cert", 12, test_util.load_vector("cert-san.pem")) self.assertEqual(self.test_rc.names(), ["example.com", "www.example.com"]) # Trying a non-current version - test_cert = test_util.load_vector("cert.pem") - os.unlink(self.test_rc.cert) - os.symlink(os.path.join("..", "..", "archive", "example.org", - "cert15.pem"), self.test_rc.cert) - with open(self.test_rc.cert, "w") as f: - f.write(test_cert) + self._write_out_kind("cert", 15, test_util.load_vector("cert.pem")) self.assertEqual(self.test_rc.names(12), ["example.com", "www.example.com"]) + # Testing common name is listed first + self._write_out_kind( + "cert", 12, test_util.load_vector("cert-5sans.pem")) + self.assertEqual( + self.test_rc.names(12), + ["example.com"] + ["{0}.example.com".format(c) for c in "abcd"]) + # Trying missing cert os.unlink(self.test_rc.cert) self.assertRaises(errors.CertStorageError, self.test_rc.names) @@ -480,13 +427,7 @@ class RenewableCertTests(BaseRenewableCertTest): # No pending deployment for ver in xrange(1, 6): for kind in ALL_FOUR: - where = getattr(self.test_rc, kind) - if os.path.islink(where): - os.unlink(where) - os.symlink(os.path.join("..", "..", "archive", "example.org", - "{0}{1}.pem".format(kind, ver)), where) - with open(where, "w") as f: - f.write(kind) + self._write_out_kind(kind, ver) self.assertFalse(self.test_rc.should_autodeploy()) def test_autorenewal_is_enabled(self): @@ -507,11 +448,7 @@ class RenewableCertTests(BaseRenewableCertTest): self.assertFalse(self.test_rc.should_autorenew()) self.test_rc.configuration["autorenew"] = "1" for kind in ALL_FOUR: - where = getattr(self.test_rc, kind) - os.symlink(os.path.join("..", "..", "archive", "example.org", - "{0}12.pem".format(kind)), where) - with open(where, "w") as f: - f.write(kind) + self._write_out_kind(kind, 12) # Mandatory renewal on the basis of OCSP revocation mock_ocsp.return_value = True self.assertTrue(self.test_rc.should_autorenew()) @@ -525,13 +462,7 @@ class RenewableCertTests(BaseRenewableCertTest): for ver in xrange(1, 6): for kind in ALL_FOUR: - where = getattr(self.test_rc, kind) - if os.path.islink(where): - os.unlink(where) - os.symlink(os.path.join("..", "..", "archive", "example.org", - "{0}{1}.pem".format(kind, ver)), where) - with open(where, "w") as f: - f.write(kind) + self._write_out_kind(kind, ver) self.test_rc.update_all_links_to(3) self.assertEqual( 6, self.test_rc.save_successor(3, "new cert", None, diff --git a/certbot/tests/testdata/cert-5sans.pem b/certbot/tests/testdata/cert-5sans.pem new file mode 100644 index 000000000..5de7cc6cb --- /dev/null +++ b/certbot/tests/testdata/cert-5sans.pem @@ -0,0 +1,16 @@ +-----BEGIN CERTIFICATE----- +MIICkTCCAjugAwIBAgIJAJNbfABWQ8bbMA0GCSqGSIb3DQEBCwUAMHkxCzAJBgNV +BAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRYwFAYDVQQHDA1TYW4gRnJhbmNp +c2NvMScwJQYDVQQKDB5FbGVjdHJvbmljIEZyb250aWVyIEZvdW5kYXRpb24xFDAS +BgNVBAMMC2V4YW1wbGUuY29tMB4XDTE2MDYwOTIzMDEzNloXDTE2MDcwOTIzMDEz +NloweTELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExFjAUBgNVBAcM +DVNhbiBGcmFuY2lzY28xJzAlBgNVBAoMHkVsZWN0cm9uaWMgRnJvbnRpZXIgRm91 +bmRhdGlvbjEUMBIGA1UEAwwLZXhhbXBsZS5jb20wXDANBgkqhkiG9w0BAQEFAANL +ADBIAkEArHVztFHtH92ucFJD/N/HW9AsdRsUuHUBBBDlHwNlRd3fp580rv2+6QWE +30cWgdmJS86ObRz6lUTor4R0T+3C5QIDAQABo4GlMIGiMB0GA1UdDgQWBBQmz8jt +S9eUsuQlA1gkjwTAdNWXijAfBgNVHSMEGDAWgBQmz8jtS9eUsuQlA1gkjwTAdNWX +ijAMBgNVHRMEBTADAQH/MFIGA1UdEQRLMEmCDWEuZXhhbXBsZS5jb22CDWIuZXhh +bXBsZS5jb22CDWMuZXhhbXBsZS5jb22CDWQuZXhhbXBsZS5jb22CC2V4YW1wbGUu +Y29tMA0GCSqGSIb3DQEBCwUAA0EAVXmZxB+IJdgFvY2InOYeytTD1QmouDZRtj/T +H/HIpSdsfO7qr4d/ZprI2IhLRxp2S4BiU5Qc5HUkeADcpNd06A== +-----END CERTIFICATE-----