Convert several SAN handling functions to use cryptography's APIs

This commit is contained in:
Alex Gaynor 2024-12-17 12:39:42 -05:00
parent 619da0432a
commit 0f36d0c1ba
2 changed files with 83 additions and 58 deletions

View file

@ -360,8 +360,8 @@ class GetNamesFromCertTest(unittest.TestCase):
self._call(test_util.load_vector('cert-5sans_512.pem'))
def test_parse_non_cert(self):
with pytest.raises(OpenSSL.crypto.Error):
self._call("hello there")
with pytest.raises(ValueError):
self._call(b"hello there")
class GetNamesFromReqTest(unittest.TestCase):

View file

@ -8,6 +8,7 @@ import datetime
import hashlib
import logging
import re
import typing
from typing import Callable
from typing import List
from typing import Optional
@ -181,33 +182,61 @@ def csr_matches_pubkey(csr: bytes, privkey: bytes) -> bool:
)
def import_csr_file(csrfile: str, data: bytes) -> Tuple[int, util.CSR, List[str]]:
def import_csr_file(
csrfile: str, data: bytes
) -> Tuple[acme_crypto_util.Format, util.CSR, List[str]]:
"""Import a CSR file, which can be either PEM or DER.
:param str csrfile: CSR filename
:param bytes data: contents of the CSR file
:returns: (`crypto.FILETYPE_PEM`,
:returns: (`acme_crypto_util.Format.PEM`,
util.CSR object representing the CSR,
list of domains requested in the CSR)
:rtype: tuple
"""
PEM = crypto.FILETYPE_PEM
load = crypto.load_certificate_request
try:
# Try to parse as DER first, then fall back to PEM.
csr = load(crypto.FILETYPE_ASN1, data)
except crypto.Error:
csr = x509.load_der_x509_csr(data)
except ValueError:
try:
csr = load(PEM, data)
except crypto.Error:
csr = x509.load_pem_x509_csr(data)
except ValueError:
raise errors.Error("Failed to parse CSR file: {0}".format(csrfile))
domains = _get_names_from_loaded_cert_or_req(csr)
domains = _get_names_from_subject_and_extensions(csr.subject, csr.extensions)
# Internally we always use PEM, so re-encode as PEM before returning.
data_pem = crypto.dump_certificate_request(PEM, csr)
return PEM, util.CSR(file=csrfile, data=data_pem, form="pem"), domains
data_pem = csr.public_bytes(serialization.Encoding.PEM)
return (
acme_crypto_util.Format.PEM,
util.CSR(file=csrfile, data=data_pem, form="pem"),
domains,
)
def _get_names_from_subject_and_extensions(
subject: x509.Name, exts: x509.Extensions
) -> List[str]:
# We know these are always `str` because `bytes` is only possible for
# other OIDs.
cns = [
typing.cast(str, c.value)
for c in subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)
]
try:
san_ext = exts.get_extension_for_class(x509.SubjectAlternativeName)
except x509.ExtensionNotFound:
dns_names = []
else:
dns_names = san_ext.value.get_values_for_type(x509.DNSName)
if not cns:
return dns_names
else:
# We only include the first CN, if there are multiple. This matches
# the behavior of the previously implementation using pyOpenSSL.
return [cns[0]] + [d for d in dns_names if d != cns[0]]
def make_key(bits: int = 2048, key_type: str = "rsa",
@ -408,78 +437,74 @@ def pyopenssl_load_certificate(data: bytes) -> Tuple[crypto.X509, int]:
str(error) for error in openssl_errors)))
def _load_cert_or_req(cert_or_req_str: bytes,
load_func: Callable[[int, bytes], Union[crypto.X509, crypto.X509Req]],
typ: int = crypto.FILETYPE_PEM) -> Union[crypto.X509, crypto.X509Req]:
try:
return load_func(typ, cert_or_req_str)
except crypto.Error as err:
logger.debug("", exc_info=True)
logger.error("Encountered error while loading certificate or csr: %s", str(err))
raise
def _get_sans_from_cert_or_req(cert_or_req_str: bytes,
load_func: Callable[[int, bytes], Union[crypto.X509,
crypto.X509Req]],
typ: int = crypto.FILETYPE_PEM) -> List[str]:
# pylint: disable=protected-access
return acme_crypto_util._pyopenssl_cert_or_req_san(_load_cert_or_req(
cert_or_req_str, load_func, typ))
def get_sans_from_cert(cert: bytes, typ: int = crypto.FILETYPE_PEM) -> List[str]:
def get_sans_from_cert(
cert: bytes, typ: Union[acme_crypto_util.Format, int] = acme_crypto_util.Format.PEM
) -> List[str]:
"""Get a list of Subject Alternative Names from a certificate.
:param str cert: Certificate (encoded).
:param typ: `crypto.FILETYPE_PEM` or `crypto.FILETYPE_ASN1`
:param Format typ: Which format the `cert` bytes are in.
:returns: A list of Subject Alternative Names.
:rtype: list
"""
return _get_sans_from_cert_or_req(
cert, crypto.load_certificate, typ)
typ = acme_crypto_util.Format(typ)
if typ == acme_crypto_util.Format.PEM:
x509_cert = x509.load_pem_x509_certificate(cert)
else:
assert typ == acme_crypto_util.Format.DER
x509_cert = x509.load_der_x509_certificate(cert)
try:
san_ext = x509_cert.extensions.get_extension_for_class(
x509.SubjectAlternativeName
)
except x509.ExtensionNotFound:
return []
return san_ext.value.get_values_for_type(x509.DNSName)
def _get_names_from_cert_or_req(cert_or_req: bytes,
load_func: Callable[[int, bytes], Union[crypto.X509,
crypto.X509Req]],
typ: int) -> List[str]:
loaded_cert_or_req = _load_cert_or_req(cert_or_req, load_func, typ)
return _get_names_from_loaded_cert_or_req(loaded_cert_or_req)
def _get_names_from_loaded_cert_or_req(loaded_cert_or_req: Union[crypto.X509, crypto.X509Req]
) -> List[str]:
# pylint: disable=protected-access
return acme_crypto_util._pyopenssl_cert_or_req_all_names(loaded_cert_or_req)
def get_names_from_cert(cert: bytes, typ: int = crypto.FILETYPE_PEM) -> List[str]:
def get_names_from_cert(
cert: bytes, typ: Union[acme_crypto_util.Format, int] = acme_crypto_util.Format.PEM
) -> List[str]:
"""Get a list of domains from a cert, including the CN if it is set.
:param str cert: Certificate (encoded).
:param typ: `crypto.FILETYPE_PEM` or `crypto.FILETYPE_ASN1`
:param Format typ: Which format the `cert` bytes are in.
:returns: A list of domain names.
:rtype: list
"""
return _get_names_from_cert_or_req(
cert, crypto.load_certificate, typ)
typ = acme_crypto_util.Format(typ)
if typ == acme_crypto_util.Format.PEM:
x509_cert = x509.load_pem_x509_certificate(cert)
else:
assert typ == acme_crypto_util.Format.DER
x509_cert = x509.load_der_x509_certificate(cert)
return _get_names_from_subject_and_extensions(x509_cert.subject, x509_cert.extensions)
def get_names_from_req(csr: bytes, typ: int = crypto.FILETYPE_PEM) -> List[str]:
def get_names_from_req(
csr: bytes, typ: Union[acme_crypto_util.Format, int] = acme_crypto_util.Format.PEM
) -> List[str]:
"""Get a list of domains from a CSR, including the CN if it is set.
:param str csr: CSR (encoded).
:param typ: `crypto.FILETYPE_PEM` or `crypto.FILETYPE_ASN1`
:param acme_crypto_util.Format typ: Which format the `csr` bytes are in.
:returns: A list of domain names.
:rtype: list
"""
return _get_names_from_cert_or_req(csr, crypto.load_certificate_request, typ)
typ = acme_crypto_util.Format(typ)
if typ == acme_crypto_util.Format.PEM:
x509_req = x509.load_pem_x509_csr(csr)
else:
assert typ == acme_crypto_util.Format.DER
x509_req = x509.load_der_x509_csr(csr)
return _get_names_from_subject_and_extensions(x509_req.subject, x509_req.extensions)
def dump_pyopenssl_chain(