diff --git a/bin/tests/system/chain/ans3/ans.py b/bin/tests/system/chain/ans3/ans.py index 4a87dfc89c..b61cd9d79a 100755 --- a/bin/tests/system/chain/ans3/ans.py +++ b/bin/tests/system/chain/ans3/ans.py @@ -14,6 +14,7 @@ information regarding copyright ownership. from typing import AsyncGenerator import dns.name +import dns.rcode import dns.rdataclass import dns.rdatatype import dns.rrset @@ -30,7 +31,7 @@ from isctest.asyncserver import ( try: dns_namerelation_equal = dns.name.NameRelation.EQUAL dns_namerelation_subdomain = dns.name.NameRelation.SUBDOMAIN -except AttributeError: # dnspython < 2.0.0 compat +except AttributeError: # dnspython < 2.3.0 compat dns_namerelation_equal = dns.name.NAMERELN_EQUAL # type: ignore dns_namerelation_subdomain = dns.name.NAMERELN_SUBDOMAIN # type: ignore @@ -69,7 +70,7 @@ class CnameThenDnameHandler(DomainHandler): dname_rrset = get_dname_rrset_at_name(qctx.zone, dname_owner) qctx.response.answer.append(dname_rrset) - yield DnsResponseSend(qctx.response, authoritative=True) + yield DnsResponseSend(qctx.response) class Cve202125215(DomainHandler): @@ -108,13 +109,12 @@ class Cve202125215(DomainHandler): ) qctx.response.answer.append(cname_rrset) - yield DnsResponseSend(qctx.response, authoritative=True) + yield DnsResponseSend(qctx.response) def main() -> None: - server = AsyncDnsServer(acknowledge_manual_dname_handling=True) - server.install_response_handler(CnameThenDnameHandler()) - server.install_response_handler(Cve202125215()) + server = AsyncDnsServer(acknowledge_manual_dname_handling=True, default_aa=True) + server.install_response_handlers([CnameThenDnameHandler(), Cve202125215()]) server.run() diff --git a/bin/tests/system/chain/ans4/ans.py b/bin/tests/system/chain/ans4/ans.py index 3e042ea58c..2f0d8c3352 100755 --- a/bin/tests/system/chain/ans4/ans.py +++ b/bin/tests/system/chain/ans4/ans.py @@ -19,8 +19,8 @@ import abc import logging import re +import dns.name import dns.rcode -import dns.rdata import dns.rdataclass import dns.rdatatype import dns.rrset @@ -34,11 +34,6 @@ from isctest.asyncserver import ( ResponseAction, ) -try: - RdataType = dns.rdatatype.RdataType -except AttributeError: # dnspython < 2.0.0 compat - RdataType = int # type: ignore - class ChainNameGenerator: """ @@ -105,13 +100,13 @@ class RecordGenerator(abc.ABC): @classmethod def create_rrset( - cls, owner: dns.name.Name, rrtype: RdataType, rdata: str + cls, owner: dns.name.Name, rrtype: dns.rdatatype.RdataType, rdata: str ) -> dns.rrset.RRset: return dns.rrset.from_text(owner, 86400, dns.rdataclass.IN, rrtype, rdata) @classmethod def create_rrset_signature( - cls, owner: dns.name.Name, rrtype: RdataType + cls, owner: dns.name.Name, rrtype: dns.rdatatype.RdataType ) -> dns.rrset.RRset: covers = dns.rdatatype.to_text(rrtype) ttl = "86400" @@ -443,9 +438,8 @@ class ChainResponseHandler(DomainHandler): for rrset in self._additional_rrsets: qctx.response.additional.append(rrset) - qctx.response.set_rcode(dns.rcode.NOERROR) qctx.response.use_edns() - yield DnsResponseSend(qctx.response, authoritative=True) + yield DnsResponseSend(qctx.response) def _non_chain_answer(self, qctx: QueryContext) -> List[dns.rrset.RRset]: owner = qctx.qname @@ -473,7 +467,10 @@ class ChainResponseHandler(DomainHandler): def main() -> None: - server = ControllableAsyncDnsServer(commands=[ChainSetupCommand]) + server = ControllableAsyncDnsServer( + default_aa=True, default_rcode=dns.rcode.NOERROR + ) + server.install_control_command(ChainSetupCommand()) server.run() diff --git a/bin/tests/system/cookie/cookie_ans.py b/bin/tests/system/cookie/cookie_ans.py index bd2782d0d6..50b06f2c16 100644 --- a/bin/tests/system/cookie/cookie_ans.py +++ b/bin/tests/system/cookie/cookie_ans.py @@ -11,7 +11,11 @@ from typing import AsyncGenerator -import dns +import dns.edns +import dns.name +import dns.rcode +import dns.rdatatype +import dns.rrset import dns.tsigkeyring from isctest.asyncserver import ( @@ -33,16 +37,6 @@ KEYRING = dns.tsigkeyring.from_text( ) -def _reparse_with_keyring(qctx: QueryContext) -> None: - """ - `isctest.asyncserver` doesn't support TSIG signing and validation properly - and hacks around it. However, here we need to be able to sign responses with - TSIG, so we reparse the query and recreate the response stub here. - """ - qctx.query = dns.message.from_wire(qctx.query.to_wire(), keyring=KEYRING) - qctx.response = dns.message.make_response(qctx.query) - - def _first_label(qctx: QueryContext) -> str: return qctx.qname.labels[0].decode("ascii") @@ -68,7 +62,7 @@ def _tld(qctx: QueryContext) -> dns.name.Name: def _soa(qctx: QueryContext) -> dns.rrset.RRset: return dns.rrset.from_text( - _tld(qctx), 2, dns.rdataclass.IN, dns.rdatatype.SOA, ". . 0 0 0 0 2" + _tld(qctx), 2, qctx.qclass, dns.rdatatype.SOA, ". . 0 0 0 0 2" ) @@ -80,21 +74,19 @@ def _ns(qctx: QueryContext) -> dns.rrset.RRset: return dns.rrset.from_text( qctx.qname, 1, - dns.rdataclass.IN, + qctx.qclass, dns.rdatatype.NS, _ns_name(qctx).to_text(), ) def _legit_a(qctx: QueryContext) -> dns.rrset.RRset: - return dns.rrset.from_text( - qctx.qname, 1, dns.rdataclass.IN, dns.rdatatype.A, "10.53.0.9" - ) + return dns.rrset.from_text(qctx.qname, 1, qctx.qclass, dns.rdatatype.A, "10.53.0.9") def _spoofed_a(qctx: QueryContext) -> dns.rrset.RRset: return dns.rrset.from_text( - qctx.qname, 1, dns.rdataclass.IN, dns.rdatatype.A, "10.53.0.10" + qctx.qname, 1, qctx.qclass, dns.rdatatype.A, "10.53.0.10" ) @@ -110,14 +102,13 @@ class NsHandler(_SpoofableHandler): async def get_responses( self, qctx: QueryContext ) -> AsyncGenerator[DnsResponseSend, None]: - _reparse_with_keyring(qctx) _add_cookie(qctx) qctx.response.answer.append(_ns(qctx)) if self.evil_server: qctx.response.authority.append(_spoofed_a(qctx)) else: qctx.response.authority.append(_legit_a(qctx)) - yield DnsResponseSend(qctx.response, authoritative=True) + yield DnsResponseSend(qctx.response) class GlueHandler(_SpoofableHandler): @@ -127,13 +118,12 @@ class GlueHandler(_SpoofableHandler): async def get_responses( self, qctx: QueryContext ) -> AsyncGenerator[DnsResponseSend, None]: - _reparse_with_keyring(qctx) _add_cookie(qctx) if self.evil_server: qctx.response.answer.append(_spoofed_a(qctx)) else: qctx.response.answer.append(_legit_a(qctx)) - yield DnsResponseSend(qctx.response, authoritative=True) + yield DnsResponseSend(qctx.response) class TcpAHandler(ResponseHandler): @@ -143,11 +133,10 @@ class TcpAHandler(ResponseHandler): async def get_responses( self, qctx: QueryContext ) -> AsyncGenerator[DnsResponseSend, None]: - _reparse_with_keyring(qctx) if _first_label(qctx) != "nocookie": _add_cookie(qctx) qctx.response.answer.append(_legit_a(qctx)) - yield DnsResponseSend(qctx.response, authoritative=True) + yield DnsResponseSend(qctx.response) class WithtsigUdpAHandler(ResponseHandler): @@ -161,16 +150,15 @@ class WithtsigUdpAHandler(ResponseHandler): async def get_responses( self, qctx: QueryContext ) -> AsyncGenerator[DnsResponseSend, None]: - _reparse_with_keyring(qctx) qctx.response.answer.append(_legit_a(qctx)) qctx.response.answer.append(_spoofed_a(qctx)) qctx.response.use_tsig(keyring=KEYRING, keyname="fake") - yield DnsResponseSend(qctx.response, authoritative=True) + yield DnsResponseSend(qctx.response) - _reparse_with_keyring(qctx) + qctx.prepare_new_response() _add_cookie(qctx) qctx.response.answer.append(_legit_a(qctx)) - yield DnsResponseSend(qctx.response, authoritative=True) + yield DnsResponseSend(qctx.response) class UdpAHandler(ResponseHandler): @@ -180,35 +168,39 @@ class UdpAHandler(ResponseHandler): async def get_responses( self, qctx: QueryContext ) -> AsyncGenerator[DnsResponseSend, None]: - _reparse_with_keyring(qctx) qctx.response.answer.append(_legit_a(qctx)) if _first_label(qctx) not in ("nocookie", "tcponly"): _add_cookie(qctx) else: qctx.response.answer.append(_spoofed_a(qctx)) - yield DnsResponseSend(qctx.response, authoritative=True) + yield DnsResponseSend(qctx.response) class FallbackHandler(ResponseHandler): async def get_responses( self, qctx: QueryContext ) -> AsyncGenerator[DnsResponseSend, None]: - _reparse_with_keyring(qctx) _add_cookie(qctx) if qctx.qtype == dns.rdatatype.SOA: qctx.response.answer.append(_soa(qctx)) else: qctx.response.authority.append(_soa(qctx)) - yield DnsResponseSend(qctx.response, authoritative=True) + yield DnsResponseSend(qctx.response) def cookie_server(evil: bool) -> AsyncDnsServer: - server = AsyncDnsServer(acknowledge_tsig_dnspython_hacks=True) - server.install_response_handler(NsHandler(evil)) - server.install_response_handler(GlueHandler(evil)) - server.install_response_handler(TcpAHandler()) - server.install_response_handler(WithtsigUdpAHandler()) - server.install_response_handler(UdpAHandler()) - server.install_response_handler(FallbackHandler()) + server = AsyncDnsServer( + keyring=KEYRING, default_aa=True, default_rcode=dns.rcode.NOERROR + ) + server.install_response_handlers( + [ + NsHandler(evil), + GlueHandler(evil), + TcpAHandler(), + WithtsigUdpAHandler(), + UdpAHandler(), + FallbackHandler(), + ] + ) return server diff --git a/bin/tests/system/dispatch/ans3/ans.py b/bin/tests/system/dispatch/ans3/ans.py index 653232f991..97b1f2cf1f 100644 --- a/bin/tests/system/dispatch/ans3/ans.py +++ b/bin/tests/system/dispatch/ans3/ans.py @@ -11,7 +11,8 @@ from typing import AsyncGenerator -import dns +import dns.flags +import dns.rcode from isctest.asyncserver import ( AsyncDnsServer, @@ -29,13 +30,12 @@ class TruncateOnUdpHandler(ResponseHandler): self, qctx: QueryContext ) -> AsyncGenerator[ResponseAction, None]: assert qctx.protocol == DnsProtocol.UDP, "This server only supports UDP" - qctx.response.set_rcode(dns.rcode.NOERROR) qctx.response.flags |= dns.flags.TC yield DnsResponseSend(qctx.response) def main() -> None: - server = AsyncDnsServer() + server = AsyncDnsServer(default_rcode=dns.rcode.NOERROR) server.install_connection_handler(ConnectionReset(delay=1.0)) server.install_response_handler(TruncateOnUdpHandler()) server.run() diff --git a/bin/tests/system/dnssec/ans10/ans.py b/bin/tests/system/dnssec/ans10/ans.py index 7c0798f2f8..f69d5ebe14 100644 --- a/bin/tests/system/dnssec/ans10/ans.py +++ b/bin/tests/system/dnssec/ans10/ans.py @@ -11,7 +11,8 @@ from typing import AsyncGenerator -import dns +import dns.rdatatype +import dns.rrset from isctest.asyncserver import ( AsyncDnsServer, @@ -33,7 +34,7 @@ class AddRrsigToAHandler(ResponseHandler): "gB+eISXAhSPZU2i/II0W9ZUhC2SCIrb94mlNvP5092WAeXxqN/vG43/1nmDly2Qs7y5VCjSMOGn85bnaMoAc7w==" ) rrsig_rrset = dns.rrset.from_text( - qctx.qname, 1, dns.rdataclass.IN, dns.rdatatype.RRSIG, rrsig + qctx.qname, 1, qctx.qclass, dns.rdatatype.RRSIG, rrsig ) qctx.response.answer.append(rrsig_rrset) yield DnsResponseSend(qctx.response) @@ -48,7 +49,7 @@ class AddNsecToTxtHandler(ResponseHandler): ) -> AsyncGenerator[DnsResponseSend, None]: nsec = f"{qctx.qname.to_text()} A NS SOA RRSIG NSEC" nsec_rrset = dns.rrset.from_text( - qctx.qname, 1, dns.rdataclass.IN, dns.rdatatype.NSEC, nsec + qctx.qname, 1, qctx.qclass, dns.rdatatype.NSEC, nsec ) qctx.response.authority.append(nsec_rrset) yield DnsResponseSend(qctx.response) @@ -56,8 +57,7 @@ class AddNsecToTxtHandler(ResponseHandler): def main() -> None: server = AsyncDnsServer() - server.install_response_handler(AddRrsigToAHandler()) - server.install_response_handler(AddNsecToTxtHandler()) + server.install_response_handlers([AddRrsigToAHandler(), AddNsecToTxtHandler()]) server.run() diff --git a/bin/tests/system/fetchlimit/ans4/ans.py b/bin/tests/system/fetchlimit/ans4/ans.py index 34891fa310..cd7602366b 100644 --- a/bin/tests/system/fetchlimit/ans4/ans.py +++ b/bin/tests/system/fetchlimit/ans4/ans.py @@ -13,7 +13,9 @@ information regarding copyright ownership. from typing import AsyncGenerator -import dns +import dns.rcode +import dns.rdatatype +import dns.rrset from isctest.asyncserver import ( ControllableAsyncDnsServer, @@ -33,13 +35,15 @@ class MaybeDelayedAddressAnswerHandler(ResponseHandler): rrset = dns.rrset.from_text(qctx.qname, 300, qctx.qclass, qctx.qtype, addr) qctx.response.answer.append(rrset) - qctx.response.set_rcode(dns.rcode.NOERROR) delay = 0.05 if qctx.qname.labels[0].startswith(b"latency") else 0.00 - yield DnsResponseSend(qctx.response, delay=delay, authoritative=True) + yield DnsResponseSend(qctx.response, delay=delay) def main() -> None: - server = ControllableAsyncDnsServer([ToggleResponsesCommand]) + server = ControllableAsyncDnsServer( + default_aa=True, default_rcode=dns.rcode.NOERROR + ) + server.install_control_command(ToggleResponsesCommand()) server.install_response_handler(MaybeDelayedAddressAnswerHandler()) server.run() diff --git a/bin/tests/system/forward/ans11/ans.py b/bin/tests/system/forward/ans11/ans.py index 8d0b3e9b33..b5b590aabf 100644 --- a/bin/tests/system/forward/ans11/ans.py +++ b/bin/tests/system/forward/ans11/ans.py @@ -14,6 +14,7 @@ information regarding copyright ownership. from typing import AsyncGenerator import dns.rdatatype +import dns.rrset from isctest.asyncserver import ( ControllableAsyncDnsServer, @@ -49,7 +50,8 @@ class ExtraAnswersHandler(DomainHandler): def main() -> None: - server = ControllableAsyncDnsServer(commands=[ToggleResponsesCommand]) + server = ControllableAsyncDnsServer() + server.install_control_command(ToggleResponsesCommand()) server.install_response_handler(ExtraAnswersHandler()) server.run() diff --git a/bin/tests/system/forward/ans6/ans.py b/bin/tests/system/forward/ans6/ans.py index f63cdcd4d5..fdcbe7d392 100644 --- a/bin/tests/system/forward/ans6/ans.py +++ b/bin/tests/system/forward/ans6/ans.py @@ -13,7 +13,10 @@ information regarding copyright ownership. from typing import AsyncGenerator -import dns +import dns.name +import dns.rcode +import dns.rdatatype +import dns.rrset from isctest.asyncserver import ( ControllableAsyncDnsServer, @@ -60,7 +63,6 @@ class ChaseDsHandler(ResponseHandler): response_rdata = ". . 0 0 0 0 0" response_section = qctx.response.authority - qctx.response.set_rcode(dns.rcode.NOERROR) qctx.response.use_edns(None) response_rrset = dns.rrset.from_text( @@ -68,11 +70,14 @@ class ChaseDsHandler(ResponseHandler): ) response_section.append(response_rrset) - yield DnsResponseSend(qctx.response, authoritative=True) + yield DnsResponseSend(qctx.response) def main() -> None: - server = ControllableAsyncDnsServer([ToggleResponsesCommand]) + server = ControllableAsyncDnsServer( + default_rcode=dns.rcode.NOERROR, default_aa=True + ) + server.install_control_command(ToggleResponsesCommand()) server.install_response_handler(ChaseDsHandler()) server.run() diff --git a/bin/tests/system/isctest/asyncserver.py b/bin/tests/system/isctest/asyncserver.py index c91f63a123..98fec6b663 100644 --- a/bin/tests/system/isctest/asyncserver.py +++ b/bin/tests/system/isctest/asyncserver.py @@ -21,7 +21,6 @@ from typing import ( List, Optional, Tuple, - Type, Union, cast, ) @@ -29,6 +28,7 @@ from typing import ( import abc import asyncio import contextlib +import copy import enum import functools import logging @@ -39,25 +39,21 @@ import signal import struct import sys +import dns.exception import dns.flags import dns.message import dns.name import dns.node import dns.rcode +import dns.rdata import dns.rdataclass +import dns.rdataset import dns.rdatatype import dns.rrset import dns.tsig import dns.version import dns.zone -try: - RdataType = dns.rdatatype.RdataType - RdataClass = dns.rdataclass.RdataClass -except AttributeError: # dnspython < 2.0.0 compat - RdataType = int # type: ignore - RdataClass = int # type: ignore - _UdpHandler = Callable[ [bytes, Tuple[str, int], asyncio.DatagramTransport], Coroutine[Any, Any, None] @@ -274,11 +270,17 @@ class QueryContext: response: dns.message.Message peer: Peer protocol: DnsProtocol - zone: Optional[dns.zone.Zone] = None - soa: Optional[dns.rrset.RRset] = None - node: Optional[dns.node.Node] = None - answer: Optional[dns.rdataset.Rdataset] = None - alias: Optional[dns.name.Name] = None + zone: Optional[dns.zone.Zone] = field(default=None, init=False) + soa: Optional[dns.rrset.RRset] = field(default=None, init=False) + node: Optional[dns.node.Node] = field(default=None, init=False) + answer: Optional[dns.rdataset.Rdataset] = field(default=None, init=False) + alias: Optional[dns.name.Name] = field(default=None, init=False) + _initialized_response: Optional[dns.message.Message] = field( + default=None, init=False + ) + _initialized_response_with_zone_data: Optional[dns.message.Message] = field( + default=None, init=False + ) @property def qname(self) -> dns.name.Name: @@ -289,13 +291,30 @@ class QueryContext: return self.alias or self.qname @property - def qclass(self) -> RdataClass: + def qclass(self) -> dns.rdataclass.RdataClass: return self.query.question[0].rdclass @property - def qtype(self) -> RdataType: + def qtype(self) -> dns.rdatatype.RdataType: return self.query.question[0].rdtype + def prepare_new_response( + self, /, with_zone_data: bool = True + ) -> dns.message.Message: + if with_zone_data: + assert self._initialized_response_with_zone_data + self.response = copy.deepcopy(self._initialized_response_with_zone_data) + else: + assert self._initialized_response + self.response = copy.deepcopy(self._initialized_response) + return self.response + + def save_initialized_response(self, /, with_zone_data: bool) -> None: + if with_zone_data: + self._initialized_response_with_zone_data = copy.deepcopy(self.response) + else: + self._initialized_response = copy.deepcopy(self.response) + @dataclass class ResponseAction(abc.ABC): @@ -756,6 +775,10 @@ class _DnsMessageWithTsigDisabled(dns.message.Message): return super().to_wire(*args, **kwargs) +class _NoKeyringType: + pass + + class AsyncDnsServer(AsyncServer): """ DNS server which responds to queries based on zone data and/or custom @@ -774,9 +797,13 @@ class AsyncDnsServer(AsyncServer): def __init__( self, + /, default_rcode: dns.rcode.Rcode = dns.rcode.REFUSED, + default_aa: bool = True, + keyring: Union[ + Dict[dns.name.Name, dns.tsig.Key], None, _NoKeyringType + ] = _NoKeyringType(), acknowledge_manual_dname_handling: bool = False, - acknowledge_tsig_dnspython_hacks: bool = False, ) -> None: super().__init__(self._handle_udp, self._handle_tcp, "ans.pid") @@ -784,8 +811,9 @@ class AsyncDnsServer(AsyncServer): self._connection_handler: Optional[ConnectionHandler] = None self._response_handlers: List[ResponseHandler] = [] self._default_rcode = default_rcode + self._default_aa = default_aa + self._keyring = keyring self._acknowledge_manual_dname_handling = acknowledge_manual_dname_handling - self._acknowledge_tsig_dnspython_hacks = acknowledge_tsig_dnspython_hacks self._load_zones() @@ -808,6 +836,10 @@ class AsyncDnsServer(AsyncServer): else: self._response_handlers.append(handler) + def install_response_handlers(self, handlers: List[ResponseHandler]) -> None: + for handler in handlers: + self.install_response_handler(handler) + def uninstall_response_handler(self, handler: ResponseHandler) -> None: """ Remove the specified handler from the list of response handlers. @@ -1060,10 +1092,7 @@ class AsyncDnsServer(AsyncServer): Yield wire data to send as a response over the established transport. """ try: - query = dns.message.from_wire(wire) - except dns.message.UnknownTSIGKey: - self._abort_if_tsig_signed_query_received_unless_acknowledged() - query = _DnsMessageWithTsigDisabled.from_wire(wire) + query = self._parse_message(wire) except dns.exception.DNSException as exc: logging.error("Invalid query from %s (%s): %s", peer, wire.hex(), exc) return @@ -1082,18 +1111,25 @@ class AsyncDnsServer(AsyncServer): response_length = struct.pack("!H", len(response)) yield response_length + response - def _abort_if_tsig_signed_query_received_unless_acknowledged(self) -> None: - if self._acknowledge_tsig_dnspython_hacks: - return - - error = "TSIG-signed query received; " - error += "due to a bug in dnspython, this requires some hacking around; " - error += "you may experience unexpected behavior when dealing with TSIG; " - error += "TSIG validation is disabled, so any TSIG handling must be done " - error += "manually; pass `acknowledge_tsig_dnspython_hacks=True` to the " - error += "AsyncDnsServer constructor to acknowledge this and continue." - - raise ValueError(error) + def _parse_message(self, wire: bytes) -> dns.message.Message: + try: + if isinstance(self._keyring, _NoKeyringType): + keyring = None + else: + keyring = self._keyring + return dns.message.from_wire(wire, keyring=keyring) + except dns.message.UnknownTSIGKey as exc: + if isinstance(self._keyring, _NoKeyringType): + error = "TSIG-signed query received but no `keyring` was provided; " + error += "either provide a keyring (in which case the server will " + error += "ignore any TSIG-invalid queries), or set `keyring=None` " + error += "explicitly to disable TSIG validation altogether. " + error += "This requires some hacking around a dnspython bug, " + error += "so there may be unexpected side effects." + raise ValueError(error) from exc + if self._keyring is None: + return _DnsMessageWithTsigDisabled.from_wire(wire) + raise async def _prepare_responses( self, qctx: QueryContext @@ -1102,8 +1138,12 @@ class AsyncDnsServer(AsyncServer): Yield response(s) either from response handlers or zone data. """ qctx.response.set_rcode(self._default_rcode) + if self._default_aa: + qctx.response.flags |= dns.flags.AA + qctx.save_initialized_response(with_zone_data=False) self._prepare_response_from_zone_data(qctx) + qctx.save_initialized_response(with_zone_data=True) response_handled = False async for action in self._run_response_handlers(qctx): @@ -1281,22 +1321,29 @@ class ControllableAsyncDnsServer(AsyncDnsServer): _CONTROL_DOMAIN = "_control." - def __init__(self, commands: List[Type["ControlCommand"]]): - super().__init__() - self._control_domain = dns.name.from_text(self._CONTROL_DOMAIN) - self._commands: Dict[dns.name.Name, "ControlCommand"] = {} - for command_class in commands: - command = command_class() - command_subdomain = dns.name.Name([command.control_subdomain]) - control_subdomain = command_subdomain.concatenate(self._control_domain) - try: - existing_command = self._commands[control_subdomain] - except KeyError: - self._commands[control_subdomain] = command - else: - raise RuntimeError( - f"{control_subdomain} already handled by {existing_command}" - ) + @functools.cached_property + def _control_domain(self) -> dns.name.Name: + return dns.name.from_text(self._CONTROL_DOMAIN) + + @functools.cached_property + def _commands(self) -> Dict[dns.name.Name, "ControlCommand"]: + return {} + + def install_control_commands(self, commands: List["ControlCommand"]) -> None: + for command in commands: + self.install_control_command(command) + + def install_control_command(self, command: "ControlCommand") -> None: + command_subdomain = dns.name.Name([command.control_subdomain]) + control_subdomain = command_subdomain.concatenate(self._control_domain) + try: + existing_command = self._commands[control_subdomain] + except KeyError: + self._commands[control_subdomain] = command + else: + raise RuntimeError( + f"{control_subdomain} already handled by {existing_command}" + ) async def _prepare_responses( self, qctx: QueryContext diff --git a/bin/tests/system/qmin/ans2/ans.py b/bin/tests/system/qmin/ans2/ans.py index 18f077781e..5625a611fb 100644 --- a/bin/tests/system/qmin/ans2/ans.py +++ b/bin/tests/system/qmin/ans2/ans.py @@ -16,8 +16,8 @@ from typing import AsyncGenerator import dns.message import dns.name import dns.rcode -import dns.rdataclass import dns.rdatatype +import dns.rrset from isctest.asyncserver import ( AsyncDnsServer, @@ -63,12 +63,8 @@ def send_delegation( ADDITIONAL section. """ ns_name = "ns." + zone_cut.to_text() - ns_rrset = dns.rrset.from_text( - zone_cut, 2, dns.rdataclass.IN, dns.rdatatype.NS, ns_name - ) - a_rrset = dns.rrset.from_text( - ns_name, 2, dns.rdataclass.IN, dns.rdatatype.A, target_addr - ) + ns_rrset = dns.rrset.from_text(zone_cut, 2, qctx.qclass, dns.rdatatype.NS, ns_name) + a_rrset = dns.rrset.from_text(ns_name, 2, qctx.qclass, dns.rdatatype.A, target_addr) response = dns.message.make_response(qctx.query) response.set_rcode(dns.rcode.NOERROR) @@ -103,11 +99,15 @@ class StaleHandler(DomainHandler): def main() -> None: server = AsyncDnsServer() - server.install_response_handler(QueryLogger()) - server.install_response_handler(BadHandler()) - server.install_response_handler(UglyHandler()) - server.install_response_handler(SlowHandler()) - server.install_response_handler(StaleHandler()) + server.install_response_handlers( + [ + QueryLogger(), + BadHandler(), + UglyHandler(), + SlowHandler(), + StaleHandler(), + ] + ) server.run() diff --git a/bin/tests/system/qmin/ans3/ans.py b/bin/tests/system/qmin/ans3/ans.py index 6547dd2f9b..101ea2a14f 100644 --- a/bin/tests/system/qmin/ans3/ans.py +++ b/bin/tests/system/qmin/ans3/ans.py @@ -39,10 +39,14 @@ class ZoopBoingSlowHandler(DelayedResponseHandler): def main() -> None: server = AsyncDnsServer() - server.install_response_handler(QueryLogger()) - server.install_response_handler(ZoopBoingBadHandler()) - server.install_response_handler(ZoopBoingUglyHandler()) - server.install_response_handler(ZoopBoingSlowHandler()) + server.install_response_handlers( + [ + QueryLogger(), + ZoopBoingBadHandler(), + ZoopBoingUglyHandler(), + ZoopBoingSlowHandler(), + ] + ) server.run() diff --git a/bin/tests/system/qmin/ans4/ans.py b/bin/tests/system/qmin/ans4/ans.py index ebe500bad6..74b9d9fa80 100644 --- a/bin/tests/system/qmin/ans4/ans.py +++ b/bin/tests/system/qmin/ans4/ans.py @@ -14,6 +14,7 @@ information regarding copyright ownership. from typing import AsyncGenerator import dns.rcode +import dns.rdatatype from isctest.asyncserver import ( AsyncDnsServer, @@ -85,11 +86,15 @@ class IckyPtangZoopBoingSlowHandler(DelayedResponseHandler): def main() -> None: server = AsyncDnsServer() - server.install_response_handler(QueryLogger()) - server.install_response_handler(StaleHandler()) - server.install_response_handler(IckyPtangZoopBoingBadHandler()) - server.install_response_handler(IckyPtangZoopBoingUglyHandler()) - server.install_response_handler(IckyPtangZoopBoingSlowHandler()) + server.install_response_handlers( + [ + QueryLogger(), + StaleHandler(), + IckyPtangZoopBoingBadHandler(), + IckyPtangZoopBoingUglyHandler(), + IckyPtangZoopBoingSlowHandler(), + ] + ) server.run() diff --git a/bin/tests/system/qmin/qmin_ans.py b/bin/tests/system/qmin/qmin_ans.py index c610eb5726..6185e15a10 100644 --- a/bin/tests/system/qmin/qmin_ans.py +++ b/bin/tests/system/qmin/qmin_ans.py @@ -16,7 +16,6 @@ from typing import AsyncGenerator import abc import dns.rcode -import dns.rdataclass import dns.rdatatype from isctest.asyncserver import ( @@ -26,8 +25,6 @@ from isctest.asyncserver import ( ResponseAction, ) -from isctest.compat import dns_rcode - def log_query(qctx: QueryContext) -> None: """ @@ -67,7 +64,7 @@ class EntRcodeChanger(DomainHandler): @property @abc.abstractmethod - def rcode(self) -> dns_rcode: + def rcode(self) -> dns.rcode.Rcode: raise NotImplementedError async def get_responses( diff --git a/bin/tests/system/rpzrecurse/ans5/ans.py b/bin/tests/system/rpzrecurse/ans5/ans.py index 44de641044..3132fca091 100644 --- a/bin/tests/system/rpzrecurse/ans5/ans.py +++ b/bin/tests/system/rpzrecurse/ans5/ans.py @@ -13,7 +13,9 @@ information regarding copyright ownership. from typing import AsyncGenerator -import dns +import dns.rcode +import dns.rdatatype +import dns.rrset from isctest.asyncserver import ( AsyncDnsServer, @@ -32,11 +34,10 @@ class ReplyA(ResponseHandler): self, qctx: QueryContext ) -> AsyncGenerator[DnsResponseSend, None]: a_rrset = dns.rrset.from_text( - qctx.qname, 300, dns.rdataclass.IN, dns.rdatatype.A, "10.53.0.5" + qctx.qname, 300, qctx.qclass, dns.rdatatype.A, "10.53.0.5" ) qctx.response.answer.append(a_rrset) - qctx.response.set_rcode(dns.rcode.NOERROR) - yield DnsResponseSend(qctx.response, authoritative=True) + yield DnsResponseSend(qctx.response) class IgnoreNs(ResponseHandler): @@ -49,19 +50,9 @@ class IgnoreNs(ResponseHandler): yield ResponseDrop() -class FallbackHandler(ResponseHandler): - async def get_responses( - self, qctx: QueryContext - ) -> AsyncGenerator[DnsResponseSend, None]: - qctx.response.set_rcode(dns.rcode.NOERROR) - yield DnsResponseSend(qctx.response, authoritative=True) - - def main() -> None: - server = AsyncDnsServer() - server.install_response_handler(ReplyA()) - server.install_response_handler(IgnoreNs()) - server.install_response_handler(FallbackHandler()) + server = AsyncDnsServer(default_aa=True, default_rcode=dns.rcode.NOERROR) + server.install_response_handlers([ReplyA(), IgnoreNs()]) server.run() diff --git a/bin/tests/system/statistics/ans4/ans.py b/bin/tests/system/statistics/ans4/ans.py index d3de81bc70..a5aa118ade 100644 --- a/bin/tests/system/statistics/ans4/ans.py +++ b/bin/tests/system/statistics/ans4/ans.py @@ -13,7 +13,9 @@ information regarding copyright ownership. from typing import AsyncGenerator -import dns +import dns.rcode +import dns.rdatatype +import dns.rrset from isctest.asyncserver import ( AsyncDnsServer, @@ -159,18 +161,19 @@ class FallbackHandler(ResponseHandler): def main() -> None: server = AsyncDnsServer(default_rcode=dns.rcode.NOERROR) - for handler in ( - BadGoodCnameHandler, - Cname1Handler, - Cname2Handler, - ExampleHandler, - FooInfoHandler, - NoDataHandler, - NxdomainHandler, - SubHandler, - FallbackHandler, - ): - server.install_response_handler(handler()) + server.install_response_handlers( + [ + BadGoodCnameHandler(), + Cname1Handler(), + Cname2Handler(), + ExampleHandler(), + FooInfoHandler(), + NoDataHandler(), + NxdomainHandler(), + SubHandler(), + FallbackHandler(), + ] + ) server.run() diff --git a/bin/tests/system/tsig/ans2/ans.py b/bin/tests/system/tsig/ans2/ans.py index 65548e69ef..677a57cf8f 100644 --- a/bin/tests/system/tsig/ans2/ans.py +++ b/bin/tests/system/tsig/ans2/ans.py @@ -40,7 +40,7 @@ class TruncatedWithLastByteDroppedHandler(ResponseHandler): def main() -> None: - server = AsyncDnsServer(acknowledge_tsig_dnspython_hacks=True) + server = AsyncDnsServer(keyring=None) server.install_response_handler(TruncatedWithLastByteDroppedHandler()) server.run() diff --git a/bin/tests/system/xfer/ans9/ans.py b/bin/tests/system/xfer/ans9/ans.py index 56c80becb9..2e3a1f8be2 100644 --- a/bin/tests/system/xfer/ans9/ans.py +++ b/bin/tests/system/xfer/ans9/ans.py @@ -13,8 +13,7 @@ information regarding copyright ownership. from typing import AsyncGenerator -import dns.message -import dns.rdataclass +import dns.rcode import dns.rdatatype import dns.rrset @@ -48,17 +47,17 @@ class AXFRServer(DomainHandler): # expected to send a SOA query over UDP and then an AXFR query over # TCP. Responses to both of those start with a SOA RRset in the ANSWER # section :-) - soa_message = dns.message.make_response(qctx.query) + soa_message = qctx.response soa_rrset = dns.rrset.from_text( qctx.qname, 300, - dns.rdataclass.IN, + qctx.qclass, dns.rdatatype.SOA, f". . {self.soa_version} 0 0 0 0", ) soa_message.answer.append(soa_rrset) - yield DnsResponseSend(soa_message, authoritative=True) + yield DnsResponseSend(soa_message) if qctx.qtype == dns.rdatatype.SOA: # If QTYPE=SOA, the SOA record is the complete response. @@ -77,35 +76,38 @@ class AXFRServer(DomainHandler): # Send just the obligatory NS RRset at zone apex in the next message. # This is stupidly inefficient, but makes looping below simpler as we # will already have been done with the mandatory stuff by then. - ns_message = dns.message.make_response(qctx.query) + ns_message = qctx.prepare_new_response() ns_rrset = dns.rrset.from_text( - qctx.qname, 300, dns.rdataclass.IN, dns.rdatatype.NS, "." + qctx.qname, 300, qctx.qclass, dns.rdatatype.NS, "." ) ns_message.answer.append(ns_rrset) - yield DnsResponseSend(ns_message, authoritative=True) + yield DnsResponseSend(ns_message) # Generate the AXFR with a txt rrset. - txt_message = dns.message.make_response(qctx.query) + txt_message = qctx.prepare_new_response() txt_rrset = dns.rrset.from_text( qctx.qname, 300, - dns.rdataclass.IN, + qctx.qclass, dns.rdatatype.TXT, "foo bar", ) txt_message.answer.append(txt_rrset) - yield DnsResponseSend(txt_message, authoritative=True) + yield DnsResponseSend(txt_message) # Finish the AXFR transaction by sending the second SOA RRset. - yield DnsResponseSend(soa_message, authoritative=True) + yield DnsResponseSend(soa_message) # This makes sure that the next SOA request causes a new zone transfer self.soa_version += 1 if __name__ == "__main__": - server = ControllableAsyncDnsServer([ToggleResponsesCommand]) + server = ControllableAsyncDnsServer( + default_aa=True, default_rcode=dns.rcode.NOERROR + ) + server.install_control_command(ToggleResponsesCommand()) server.install_response_handler(AXFRServer()) server.run() diff --git a/bin/tests/system/zero/ans5/ans.py b/bin/tests/system/zero/ans5/ans.py index 47b4da0ad0..a7f63913cf 100644 --- a/bin/tests/system/zero/ans5/ans.py +++ b/bin/tests/system/zero/ans5/ans.py @@ -14,10 +14,7 @@ information regarding copyright ownership. import ipaddress from typing import AsyncGenerator -import dns.flags -import dns.message -import dns.rdata -import dns.rdataclass +import dns.rcode import dns.rdatatype import dns.rrset @@ -48,12 +45,11 @@ class IncrementARecordHandler(ResponseHandler): qctx.response.answer.append(rrset) self._ip_address += 1 - qctx.response.set_rcode(dns.rcode.NOERROR) - yield DnsResponseSend(qctx.response, authoritative=True) + yield DnsResponseSend(qctx.response) def main() -> None: - server = AsyncDnsServer() + server = AsyncDnsServer(default_aa=True, default_rcode=dns.rcode.NOERROR) server.install_response_handler(IncrementARecordHandler()) server.run()