diff --git a/bin/tests/system/isctest/check.py b/bin/tests/system/isctest/check.py index b705bfee21..7ca7779a3d 100644 --- a/bin/tests/system/isctest/check.py +++ b/bin/tests/system/isctest/check.py @@ -12,15 +12,16 @@ import difflib import shutil import os -from typing import Optional +from typing import cast, List, Optional +import dns.edns import dns.flags import dns.message import dns.rcode import dns.zone import isctest.log -from isctest.compat import dns_rcode +from isctest.compat import dns_rcode, EDECode, EDEOption def rcode(message: dns.message.Message, expected_rcode) -> None: @@ -67,6 +68,54 @@ def noraflag(message: dns.message.Message) -> None: assert (message.flags & dns.flags.RA) == 0, str(message) +def _extract_ede_options( + message: dns.message.Message, +) -> List[EDEOption]: + """Extract EDE options from the DNS message.""" + return cast( + List[EDEOption], + [ + option + for option in message.options + if option.otype == dns.edns.OptionType.EDE + ], + ) + + +def noede(message: dns.message.Message) -> None: + """Check that message contains no EDE option.""" + if not hasattr(dns.edns, "EDECode"): + # dnspython<2.2.0 doesn't support EDE, skip check + return + + ede_options = _extract_ede_options(message) + assert not ede_options, f"unexpected EDE options {ede_options} in {message}" + + +def ede( + message: dns.message.Message, code: EDECode, text: Optional[str] = None +) -> None: + """Check if message contains expected EDE code (and its text).""" + if not hasattr(dns.edns, "EDECode"): + # dnspython<2.2.0 doesn't support EDE, skip check + return + + msg_opts = _extract_ede_options(message) + matching_opts = [opt for opt in msg_opts if opt.code == code] + + assert matching_opts, f"missing EDE code {code} in {message}" + + if text is None: + return + + # check at least one matching EDE option has the required text + for opt in matching_opts: + if opt.text == text: + return + opt_str = ", ".join([opt.to_text() for opt in matching_opts]) + assert False, f'EDE text "{text}" not found in [{opt_str}]' + + def section_equal(first_section: list, second_section: list) -> None: for rrset in first_section: assert ( diff --git a/bin/tests/system/isctest/compat.py b/bin/tests/system/isctest/compat.py index 5580f1f4c5..3dc5810745 100644 --- a/bin/tests/system/isctest/compat.py +++ b/bin/tests/system/isctest/compat.py @@ -9,8 +9,9 @@ # See the COPYRIGHT file distributed with this work for additional # information regarding copyright ownership. -from typing import Any +from typing import Any, TYPE_CHECKING +import dns.edns import dns.rcode # compatiblity with dnspython<2.0.0 @@ -22,3 +23,34 @@ except AttributeError: # In dnspython<2.0.0, selected rcodes are available as integers directly # from dns.rcode dns_rcode = dns.rcode + + +if TYPE_CHECKING: + EDECode = dns.edns.EDECode + EDEOption = dns.edns.EDEOption +else: + try: # compatiblity with dnspython<2.2.0 + EDECode = dns.edns.EDECode + except AttributeError: + # In dnspython<2.2.0, the dns.edns.EDECode doesn't exist. + # + # The primary use-case is for us to use existing EDECode objects from the + # class, e.g. EDECode.FILTERED. To mimick this behavior, use a string + # factory that just turns the attribute name into a string. + # + # The used compatibility hack doesn't really matter (as long as EDECode.xxx + # doesn't raise exception), as with dnspython versions prior to 2.2.0, any + # EDE checking will be skipped anyway. + class _CompatEDECode: + def __getattr__(self, name: str) -> str: + return name + + EDECode = _CompatEDECode() + try: + EDEOption = dns.edns.EDEOption + except AttributeError: + # In dnspython<2.2.0, the dns.edns.EDEOption doesn't exist, so we stub it to be + # able to use it in type annotations. + class EDEOption: + def __new__(cls, *args, **kwargs): + raise RuntimeError("Using EDEOption requires dnspython>=2.2.0")