Sync asyncserver.py with the development branch

Import bin/tests/system/isctest/asyncserver.py as present in commit
ced002c4ab on the "main" branch.  This
enables using newer asyncserver.py infrastructure code in system tests
that need to be backported to maintenance branches.
This commit is contained in:
Michał Kępień 2026-04-17 17:57:05 +02:00
parent d2a67ba222
commit b0e8966647
No known key found for this signature in database

View file

@ -11,20 +11,9 @@ See the COPYRIGHT file distributed with this work for additional
information regarding copyright ownership.
"""
from collections.abc import AsyncGenerator, Callable, Coroutine, Sequence
from dataclasses import dataclass, field
from typing import (
Any,
AsyncGenerator,
Callable,
Coroutine,
Dict,
List,
Optional,
Set,
Tuple,
Union,
cast,
)
from typing import Any, cast
import abc
import asyncio
@ -52,11 +41,10 @@ import dns.rdataset
import dns.rdatatype
import dns.rrset
import dns.tsig
import dns.version
import dns.zone
_UdpHandler = Callable[
[bytes, Tuple[str, int], asyncio.DatagramTransport], Coroutine[Any, Any, None]
[bytes, tuple[str, int], asyncio.DatagramTransport], Coroutine[Any, Any, None]
]
@ -74,7 +62,7 @@ class _AsyncUdpHandler(asyncio.DatagramProtocol):
self,
handler: _UdpHandler,
) -> None:
self._transport: Optional[asyncio.DatagramTransport] = None
self._transport: asyncio.DatagramTransport | None = None
self._handler: _UdpHandler = handler
def connection_made(self, transport: asyncio.BaseTransport) -> None:
@ -83,7 +71,7 @@ class _AsyncUdpHandler(asyncio.DatagramProtocol):
"""
self._transport = cast(asyncio.DatagramTransport, transport)
def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None:
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
"""
Called by asyncio when a datagram is received.
"""
@ -108,9 +96,9 @@ class AsyncServer:
def __init__(
self,
udp_handler: Optional[_UdpHandler],
tcp_handler: Optional[_TcpHandler],
pidfile: Optional[str] = None,
udp_handler: _UdpHandler | None,
tcp_handler: _TcpHandler | None,
pidfile: str | None = None,
) -> None:
logging.basicConfig(
format="%(asctime)s %(levelname)8s %(message)s",
@ -132,12 +120,12 @@ class AsyncServer:
logging.info("Setting up IPv4 listener at %s:%d", ipv4_address, port)
logging.info("Setting up IPv6 listener at [%s]:%d", ipv6_address, port)
self._ip_addresses: Tuple[str, str] = (ipv4_address, ipv6_address)
self._ip_addresses: tuple[str, str] = (ipv4_address, ipv6_address)
self._port: int = port
self._udp_handler: Optional[_UdpHandler] = udp_handler
self._tcp_handler: Optional[_TcpHandler] = tcp_handler
self._pidfile: Optional[str] = pidfile
self._work_done: Optional[asyncio.Future] = None
self._udp_handler: _UdpHandler | None = udp_handler
self._tcp_handler: _TcpHandler | None = tcp_handler
self._pidfile: str | None = pidfile
self._work_done: asyncio.Future | None = None
def _get_ipv4_address_from_directory_name(self) -> str:
containing_directory = pathlib.Path().absolute().stem
@ -185,7 +173,7 @@ class AsyncServer:
loop.set_exception_handler(self._handle_exception)
def _handle_exception(
self, _: asyncio.AbstractEventLoop, context: Dict[str, Any]
self, _: asyncio.AbstractEventLoop, context: dict[str, Any]
) -> None:
assert self._work_done
exception = context.get("exception", RuntimeError(context["message"]))
@ -265,17 +253,16 @@ class QueryContext:
query: dns.message.Message
response: dns.message.Message
socket: Peer
peer: Peer
protocol: DnsProtocol
zone: Optional[dns.zone.Zone] = field(default=None, init=False)
soa: Optional[dns.rrset.RRset] = field(default=None, init=False)
node: Optional[dns.node.Node] = field(default=None, init=False)
answer: Optional[dns.rdataset.Rdataset] = field(default=None, init=False)
alias: Optional[dns.name.Name] = field(default=None, init=False)
_initialized_response: Optional[dns.message.Message] = field(
default=None, init=False
)
_initialized_response_with_zone_data: Optional[dns.message.Message] = field(
zone: dns.zone.Zone | None = field(default=None, init=False)
soa: dns.rrset.RRset | None = field(default=None, init=False)
node: dns.node.Node | None = field(default=None, init=False)
answer: dns.rdataset.Rdataset | None = field(default=None, init=False)
alias: dns.name.Name | None = field(default=None, init=False)
_initialized_response: dns.message.Message | None = field(default=None, init=False)
_initialized_response_with_zone_data: dns.message.Message | None = field(
default=None, init=False
)
@ -320,7 +307,7 @@ class ResponseAction(abc.ABC):
"""
@abc.abstractmethod
async def perform(self) -> Optional[Union[dns.message.Message, bytes]]:
async def perform(self) -> dns.message.Message | bytes | None:
"""
This method is expected to carry out arbitrary actions (e.g. wait for a
specific amount of time, modify the answer, etc.) and then return the
@ -343,14 +330,30 @@ class DnsResponseSend(ResponseAction):
"""
response: dns.message.Message
authoritative: Optional[bool] = None
authoritative: bool | None = None
delay: float = 0.0
acknowledge_hand_rolled_response: bool = False
async def perform(self) -> Optional[Union[dns.message.Message, bytes]]:
async def perform(self) -> dns.message.Message | bytes | None:
"""
Yield a potentially delayed response that is a dns.message.Message.
"""
assert isinstance(self.response, dns.message.Message)
if not (
_is_asyncserver_response(self.response)
or self.acknowledge_hand_rolled_response
):
error = "The response you are trying to send was not created using "
error += "AsyncDnsServer's response preparation methods. "
error += "This will break features such as automatic AA flag "
error += "and RCODE handling. If you need a fresh copy of a "
error += "response, use `QueryContext.prepare_new_response` "
error += "instead of `dns.message.make_response`. "
error += "To acknowledge this and proceed anyway, set "
error += "`acknowledge_hand_rolled_response=True` in "
error += "DnsResponseSend's constructor."
raise RuntimeError(error)
if self.authoritative is not None:
if self.authoritative:
self.response.flags |= dns.flags.AA
@ -377,7 +380,7 @@ class BytesResponseSend(ResponseAction):
response: bytes
delay: float = 0.0
async def perform(self) -> Optional[Union[dns.message.Message, bytes]]:
async def perform(self) -> dns.message.Message | bytes | None:
"""
Yield a potentially delayed response that is a sequence of bytes.
"""
@ -394,7 +397,7 @@ class ResponseDrop(ResponseAction):
Action which does nothing - as if a packet was dropped.
"""
async def perform(self) -> Optional[Union[dns.message.Message, bytes]]:
async def perform(self) -> dns.message.Message | bytes | None:
return None
@ -403,17 +406,16 @@ class _ConnectionTeardownRequested(Exception):
@dataclass
class ResponseDropAndCloseConnection(ResponseAction):
class CloseConnection(ResponseAction):
"""
Action which makes the server close the connection after the DNS query is
received by the server (TCP only).
Action which makes the server close the connection (TCP only).
The connection may be closed with a delay if requested.
"""
delay: float = 0.0
async def perform(self) -> Optional[Union[dns.message.Message, bytes]]:
async def perform(self) -> dns.message.Message | bytes | None:
if self.delay > 0:
logging.info("Waiting %.1fs before closing TCP connection", self.delay)
await asyncio.sleep(self.delay)
@ -495,7 +497,7 @@ class IgnoreAllConnections(ConnectionHandler):
client socket, effectively ignoring all incoming connections.
"""
_connections: Set[asyncio.StreamWriter] = field(default_factory=set)
_connections: set[asyncio.StreamWriter] = field(default_factory=set)
async def handle(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer: Peer
@ -529,8 +531,8 @@ class ConnectionReset(ConnectionHandler):
make the server send an RST segment; this happens when the server closes a
client's socket while there is still unread data in that socket's buffer.
If closing the connection _after_ the query is read by the server is enough
for a given use case, the ResponseDropAndCloseConnection response handler
should be used instead.
for a given use case, the CloseConnection response handler should be used
instead.
"""
delay: float = 0.0
@ -606,14 +608,14 @@ class QnameHandler(ResponseHandler):
@property
@abc.abstractmethod
def qnames(self) -> List[str]:
def qnames(self) -> list[str]:
"""
A list of QNAMEs handled by this class.
"""
raise NotImplementedError
def __init__(self) -> None:
self._qnames: List[dns.name.Name] = [dns.name.from_text(d) for d in self.qnames]
self._qnames: list[dns.name.Name] = [dns.name.from_text(d) for d in self.qnames]
def __str__(self) -> str:
return f"{self.__class__.__name__}(QNAMEs: {', '.join(self.qnames)})"
@ -626,6 +628,105 @@ class QnameHandler(ResponseHandler):
return qctx.qname in self._qnames
class QnameQtypeHandler(QnameHandler):
"""
Handle queries for which both of the following conditions are true:
- the query's QNAME is present in `self.qnames`,
- the query's QTYPE is present in `self.qtypes`.
"""
@property
@abc.abstractmethod
def qtypes(self) -> list[dns.rdatatype.RdataType]:
"""
A list of QTYPEs handled by this class.
"""
raise NotImplementedError
def __init__(self) -> None:
super().__init__()
self._qtypes: list[dns.rdatatype.RdataType] = self.qtypes
def __str__(self) -> str:
return f"{self.__class__.__name__}(QNAMEs: {', '.join(self.qnames)}; QTYPEs: {', '.join(map(str, self.qtypes))})"
def match(self, qctx: QueryContext) -> bool:
"""
Handle queries whose QNAME and QTYPE match any of the QNAMEs and
QTYPEs handled by this class.
"""
return qctx.qtype in self._qtypes and super().match(qctx)
class StaticResponseHandler(ResponseHandler):
"""
Base class used for deriving custom static response handlers.
The derived class can specify the RRsets to be included in the answer,
authority, and additional sections of the response, whether to set the AA
bit in the response, and a delay before sending the response.
The default implementation of `get_responses()` uses these properties to
prepare and yield a single response.
"""
@property
def rcode(self) -> dns.rcode.Rcode | None:
"""
Optional RCODE to be set in the response.
"""
return None
@property
def answer(self) -> Sequence[dns.rrset.RRset]:
"""
RRsets to be included in the answer section of the response.
"""
return []
@property
def authority(self) -> Sequence[dns.rrset.RRset]:
"""
RRsets to be included in the authority section of the response.
"""
return []
@property
def additional(self) -> Sequence[dns.rrset.RRset]:
"""
RRsets to be included in the additional section of the response.
"""
return []
@property
def authoritative(self) -> bool | None:
"""
Whether to set the AA bit in the response.
"""
return None
@property
def delay(self) -> float:
"""
Delay before sending the response.
"""
return 0.0
async def get_responses(
self, qctx: QueryContext
) -> AsyncGenerator[DnsResponseSend, None]:
qctx.prepare_new_response(with_zone_data=False)
qctx.response.answer.extend(self.answer)
qctx.response.authority.extend(self.authority)
qctx.response.additional.extend(self.additional)
if self.rcode is not None:
qctx.response.set_rcode(self.rcode)
yield DnsResponseSend(
qctx.response, authoritative=self.authoritative, delay=self.delay
)
class DomainHandler(ResponseHandler):
"""
Base class used for deriving custom domain handlers.
@ -633,20 +734,28 @@ class DomainHandler(ResponseHandler):
The derived class must specify a list of `domains` that it wants to handle.
Queries for any of these domains (and their subdomains) will then be passed
to the `get_response()` method in the derived class.
The most specific matching domain is stored in the `matched_domain` attribute.
"""
@property
@abc.abstractmethod
def domains(self) -> List[str]:
def domains(self) -> list[str]:
"""
A list of domain names handled by this class.
"""
raise NotImplementedError
def __init__(self) -> None:
self._domains: List[dns.name.Name] = [
dns.name.from_text(d) for d in self.domains
]
self._domains: list[dns.name.Name] = sorted(
[dns.name.from_text(d) for d in self.domains], reverse=True
)
self._matched_domain: dns.name.Name | None = None
@property
def matched_domain(self) -> dns.name.Name:
assert self._matched_domain is not None
return self._matched_domain
def __str__(self) -> str:
return f"{self.__class__.__name__}(domains: {', '.join(self.domains)})"
@ -656,20 +765,124 @@ class DomainHandler(ResponseHandler):
Handle queries whose QNAME matches any of the domains handled by this
class.
"""
self._matched_domain = None
for domain in self._domains:
if qctx.qname.is_subdomain(domain):
self._matched_domain = domain
return True
return False
class ForwarderHandler(ResponseHandler):
"""
A handler forwarding all received queries to another DNS server with an
optional delay and then relaying the responses back to the original client.
Queries are currently always forwarded via UDP.
"""
@property
@abc.abstractmethod
def target(self) -> str:
"""
The address of the DNS server to forward queries to.
"""
raise NotImplementedError
@property
def port(self) -> int:
"""
The port of the DNS server to forward queries to.
The default value of 0 causes the same port as the one used by this
server for listening to be used.
"""
return 0
@property
def delay(self) -> float:
"""
The number of seconds to wait before forwarding each query.
"""
return 0.0
def __str__(self) -> str:
return f"{self.__class__.__name__}(target: {self.target}:{self.port})"
class ForwarderProtocol(asyncio.DatagramProtocol):
def __init__(self, query: bytes, response: asyncio.Future) -> None:
self._query = query
self._response = response
def connection_made(self, transport: asyncio.BaseTransport) -> None:
logging.debug("[OUT] %s", self._query.hex())
cast(asyncio.DatagramTransport, transport).sendto(self._query)
def datagram_received(self, data: bytes, _: tuple[str, int]) -> None:
logging.debug("[IN] %s", data.hex())
self._response.set_result(data)
async def get_responses(
self, qctx: QueryContext
) -> AsyncGenerator[ResponseAction, None]:
loop = asyncio.get_running_loop()
response = loop.create_future()
forwarding_target = f"{self.target}:{self.port or qctx.socket.port}"
if self.delay > 0:
logging.info(
"Waiting %.1fs before forwarding %s query from %s to %s over UDP",
self.delay,
qctx.protocol.name,
qctx.peer,
forwarding_target,
)
await asyncio.sleep(self.delay)
logging.info(
"Forwarding %s query from %s to %s over UDP",
qctx.protocol.name,
qctx.peer,
forwarding_target,
)
transport, _ = await loop.create_datagram_endpoint(
lambda: self.ForwarderProtocol(qctx.query.to_wire(), response),
local_addr=(qctx.socket.host, 0),
remote_addr=(self.target, self.port or qctx.socket.port),
)
try:
await response
finally:
transport.close()
logging.info(
"Relaying UDP response from %s to %s over %s",
forwarding_target,
qctx.peer,
qctx.protocol.name,
)
try:
message = _DnsMessageWithTsigDisabled.from_wire(response.result())
yield DnsResponseSend(message, acknowledge_hand_rolled_response=True)
except dns.exception.DNSException:
logging.warning(
"Failed to parse response from %s as a DNS message, relaying it as raw bytes",
forwarding_target,
)
yield BytesResponseSend(response.result())
@dataclass
class _ZoneTreeNode:
"""
A node representing a zone with one origin.
"""
zone: Optional[dns.zone.Zone]
children: List["_ZoneTreeNode"] = field(default_factory=list)
zone: dns.zone.Zone | None
children: list["_ZoneTreeNode"] = field(default_factory=list)
class _ZoneTree:
@ -719,7 +932,7 @@ class _ZoneTree:
node_from.children.remove(child)
node_to.children.append(child)
def find_best_zone(self, name: dns.name.Name) -> Optional[dns.zone.Zone]:
def find_best_zone(self, name: dns.name.Name) -> dns.zone.Zone | None:
"""
Return the closest matching zone (if any) for the domain name.
"""
@ -737,7 +950,7 @@ class _DnsMessageWithTsigDisabled(dns.message.Message):
"""
class _DisableTsigHandling(contextlib.ContextDecorator):
def __init__(self, message: Optional[dns.message.Message] = None) -> None:
def __init__(self, message: dns.message.Message | None = None) -> None:
self.original_tsig_sign = dns.tsig.sign
self.original_tsig_validate = dns.tsig.validate
if message:
@ -749,7 +962,7 @@ class _DnsMessageWithTsigDisabled(dns.message.Message):
from failing on messages initialized with `dns.message.from_wire(keyring=False)`.
"""
def sign(*_: Any, **__: Any) -> Tuple[dns.rdata.Rdata, None]:
def sign(*_: Any, **__: Any) -> tuple[dns.rdata.Rdata, None]:
assert self.tsig
return self.tsig[0], None
@ -792,6 +1005,19 @@ class _NoKeyringType:
pass
_ASYNCSERVER_RESPONSE_MARKER = "__is_asyncserver_response__"
def _make_asyncserver_response(query: dns.message.Message) -> dns.message.Message:
response = dns.message.make_response(query)
setattr(response, _ASYNCSERVER_RESPONSE_MARKER, True)
return response
def _is_asyncserver_response(message: dns.message.Message) -> bool:
return getattr(message, _ASYNCSERVER_RESPONSE_MARKER, False)
class AsyncDnsServer(AsyncServer):
"""
DNS server which responds to queries based on zone data and/or custom
@ -812,17 +1038,17 @@ class AsyncDnsServer(AsyncServer):
self,
/,
default_rcode: dns.rcode.Rcode = dns.rcode.REFUSED,
default_aa: bool = True,
keyring: Union[
Dict[dns.name.Name, dns.tsig.Key], None, _NoKeyringType
] = _NoKeyringType(),
default_aa: bool = False,
keyring: (
dict[dns.name.Name, dns.tsig.Key] | None | _NoKeyringType
) = _NoKeyringType(),
acknowledge_manual_dname_handling: bool = False,
) -> None:
super().__init__(self._handle_udp, self._handle_tcp, "ans.pid")
self._zone_tree: _ZoneTree = _ZoneTree()
self._connection_handler: Optional[ConnectionHandler] = None
self._response_handlers: List[ResponseHandler] = []
self._connection_handler: ConnectionHandler | None = None
self._response_handlers: list[ResponseHandler] = []
self._default_rcode = default_rcode
self._default_aa = default_aa
self._keyring = keyring
@ -849,10 +1075,18 @@ class AsyncDnsServer(AsyncServer):
else:
self._response_handlers.append(handler)
def install_response_handlers(self, handlers: List[ResponseHandler]) -> None:
def install_response_handlers(self, *handlers: ResponseHandler) -> None:
for handler in handlers:
self.install_response_handler(handler)
def replace_response_handlers(self, *new_handlers: ResponseHandler) -> None:
"""
Uninstall all currently installed handlers and install the provided ones.
"""
logging.info("Uninstalling response handlers: %s", str(self._response_handlers))
self._response_handlers.clear()
self.install_response_handlers(*new_handlers)
def uninstall_response_handler(self, handler: ResponseHandler) -> None:
"""
Remove the specified handler from the list of response handlers.
@ -923,11 +1157,13 @@ class AsyncDnsServer(AsyncServer):
raise ValueError(error)
async def _handle_udp(
self, wire: bytes, addr: Tuple[str, int], transport: asyncio.DatagramTransport
self, wire: bytes, addr: tuple[str, int], transport: asyncio.DatagramTransport
) -> None:
logging.debug("Received UDP message: %s", wire.hex())
socket_info = transport.get_extra_info("sockname")
socket = Peer(socket_info[0], socket_info[1])
peer = Peer(addr[0], addr[1])
responses = self._handle_query(wire, peer, DnsProtocol.UDP)
responses = self._handle_query(wire, socket, peer, DnsProtocol.UDP)
async for response in responses:
logging.debug("Sending UDP message: %s", response.hex())
transport.sendto(response, addr)
@ -964,7 +1200,7 @@ class AsyncDnsServer(AsyncServer):
async def _read_tcp_query(
self, reader: asyncio.StreamReader, peer: Peer
) -> Optional[bytes]:
) -> bytes | None:
wire_length = await self._read_tcp_query_wire_length(reader, peer)
if not wire_length:
return None
@ -973,7 +1209,7 @@ class AsyncDnsServer(AsyncServer):
async def _read_tcp_query_wire_length(
self, reader: asyncio.StreamReader, peer: Peer
) -> Optional[int]:
) -> int | None:
logging.debug("Receiving TCP message length from %s...", peer)
wire_length_bytes = await self._read_tcp_octets(reader, peer, 2)
@ -986,7 +1222,7 @@ class AsyncDnsServer(AsyncServer):
async def _read_tcp_query_wire(
self, reader: asyncio.StreamReader, peer: Peer, wire_length: int
) -> Optional[bytes]:
) -> bytes | None:
logging.debug("Receiving TCP message (%d octets) from %s...", wire_length, peer)
wire = await self._read_tcp_octets(reader, peer, wire_length)
@ -999,7 +1235,7 @@ class AsyncDnsServer(AsyncServer):
async def _read_tcp_octets(
self, reader: asyncio.StreamReader, peer: Peer, expected: int
) -> Optional[bytes]:
) -> bytes | None:
buffer = b""
while len(buffer) < expected:
@ -1024,39 +1260,39 @@ class AsyncDnsServer(AsyncServer):
async def _send_tcp_response(
self, writer: asyncio.StreamWriter, peer: Peer, wire: bytes
) -> None:
responses = self._handle_query(wire, peer, DnsProtocol.TCP)
socket_info = writer.get_extra_info("sockname")
socket = Peer(socket_info[0], socket_info[1])
responses = self._handle_query(wire, socket, peer, DnsProtocol.TCP)
async for response in responses:
logging.debug("Sending TCP response: %s", response.hex())
writer.write(response)
await writer.drain()
def _log_query(self, qctx: QueryContext, peer: Peer, protocol: DnsProtocol) -> None:
def _log_query(self, qctx: QueryContext) -> None:
logging.info(
"Received %s/%s/%s (ID=%d) query from %s (%s)",
"Received %s/%s/%s (ID=%d) query from %s on %s (%s)",
qctx.qname.to_text(omit_final_dot=True),
dns.rdataclass.to_text(qctx.qclass),
dns.rdatatype.to_text(qctx.qtype),
qctx.query.id,
peer,
protocol.name,
qctx.peer,
qctx.socket,
qctx.protocol.name,
)
logging.debug(
"\n".join([f"[IN] {l}" for l in [""] + str(qctx.query).splitlines()])
)
def _log_response(
self,
qctx: QueryContext,
response: Optional[Union[dns.message.Message, bytes]],
peer: Peer,
protocol: DnsProtocol,
self, qctx: QueryContext, response: dns.message.Message | bytes | None
) -> None:
if not response:
logging.info(
"Not sending a response to query (ID=%d) from %s (%s)",
"Not sending a response to query (ID=%d) from %s on %s (%s)",
qctx.query.id,
peer,
protocol.name,
qctx.peer,
qctx.socket,
qctx.protocol.name,
)
return
@ -1071,7 +1307,7 @@ class AsyncDnsServer(AsyncServer):
qtype = "-"
logging.info(
"Sending %s/%s/%s (ID=%d) response (%d/%d/%d/%d) to a query (ID=%d) from %s (%s)",
"Sending %s/%s/%s (ID=%d) response (%d/%d/%d/%d) to a query (ID=%d) from %s on %s (%s)",
qname,
qclass,
qtype,
@ -1081,8 +1317,9 @@ class AsyncDnsServer(AsyncServer):
len(response.authority),
len(response.additional),
qctx.query.id,
peer,
protocol.name,
qctx.peer,
qctx.socket,
qctx.protocol.name,
)
logging.debug(
"\n".join([f"[OUT] {l}" for l in [""] + str(response).splitlines()])
@ -1090,16 +1327,17 @@ class AsyncDnsServer(AsyncServer):
return
logging.info(
"Sending response (%d bytes) to a query (ID=%d) from %s (%s)",
"Sending response (%d bytes) to a query (ID=%d) from %s on %s (%s)",
len(response),
qctx.query.id,
peer,
protocol.name,
qctx.peer,
qctx.socket,
qctx.protocol.name,
)
logging.debug("[OUT] %s", response.hex())
async def _handle_query(
self, wire: bytes, peer: Peer, protocol: DnsProtocol
self, wire: bytes, socket: Peer, peer: Peer, protocol: DnsProtocol
) -> AsyncGenerator[bytes, None]:
"""
Yield wire data to send as a response over the established transport.
@ -1109,12 +1347,12 @@ class AsyncDnsServer(AsyncServer):
except dns.exception.DNSException as exc:
logging.error("Invalid query from %s (%s): %s", peer, wire.hex(), exc)
return
response_stub = dns.message.make_response(query)
qctx = QueryContext(query, response_stub, peer, protocol)
self._log_query(qctx, peer, protocol)
response_stub = _make_asyncserver_response(query)
qctx = QueryContext(query, response_stub, socket, peer, protocol)
self._log_query(qctx)
responses = self._prepare_responses(qctx)
async for response in responses:
self._log_response(qctx, response, peer, protocol)
self._log_response(qctx, response)
if response:
if isinstance(response, dns.message.Message):
response = response.to_wire(max_size=65535)
@ -1146,7 +1384,7 @@ class AsyncDnsServer(AsyncServer):
async def _prepare_responses(
self, qctx: QueryContext
) -> AsyncGenerator[Optional[Union[dns.message.Message, bytes]], None]:
) -> AsyncGenerator[dns.message.Message | bytes | None, None]:
"""
Yield response(s) either from response handlers or zone data.
"""
@ -1339,10 +1577,10 @@ class ControllableAsyncDnsServer(AsyncDnsServer):
return dns.name.from_text(self._CONTROL_DOMAIN)
@functools.cached_property
def _commands(self) -> Dict[dns.name.Name, "ControlCommand"]:
def _commands(self) -> dict[dns.name.Name, "ControlCommand"]:
return {}
def install_control_commands(self, commands: List["ControlCommand"]) -> None:
def install_control_commands(self, *commands: "ControlCommand") -> None:
for command in commands:
self.install_control_command(command)
@ -1360,7 +1598,7 @@ class ControllableAsyncDnsServer(AsyncDnsServer):
async def _prepare_responses(
self, qctx: QueryContext
) -> AsyncGenerator[Optional[Union[dns.message.Message, bytes]], None]:
) -> AsyncGenerator[dns.message.Message | bytes | None, None]:
"""
Detect and handle control queries, falling back to normal processing
for non-control queries.
@ -1373,9 +1611,7 @@ class ControllableAsyncDnsServer(AsyncDnsServer):
async for response in super()._prepare_responses(qctx):
yield response
def _handle_control_command(
self, qctx: QueryContext
) -> Optional[dns.message.Message]:
def _handle_control_command(self, qctx: QueryContext) -> dns.message.Message | None:
"""
Detect and handle control queries.
@ -1450,8 +1686,8 @@ class ControlCommand(abc.ABC):
@abc.abstractmethod
def handle(
self, args: List[str], server: ControllableAsyncDnsServer, qctx: QueryContext
) -> Optional[str]:
self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
) -> str | None:
"""
This method is expected to carry out arbitrary actions in response to a
control query. Note that it is invoked synchronously (it is not a
@ -1489,11 +1725,11 @@ class ToggleResponsesCommand(ControlCommand):
control_subdomain = "send-responses"
def __init__(self) -> None:
self._current_handler: Optional[IgnoreAllQueries] = None
self._current_handler: IgnoreAllQueries | None = None
def handle(
self, args: List[str], server: ControllableAsyncDnsServer, qctx: QueryContext
) -> Optional[str]:
self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
) -> str | None:
if len(args) != 1:
logging.error("Invalid %s query %s", self, qctx.qname)
qctx.response.set_rcode(dns.rcode.SERVFAIL)
@ -1518,3 +1754,30 @@ class ToggleResponsesCommand(ControlCommand):
logging.error("Unrecognized response sending mode '%s'", mode)
qctx.response.set_rcode(dns.rcode.SERVFAIL)
return f"unrecognized response sending mode '{mode}'"
class SwitchControlCommand(ControlCommand):
"""
Switch the server's response handlers based on the control query.
A sequence of response handlers is associated with each key. When a
control query is received, the server's response handlers are replaced
with the sequence associated with the key extracted from the control
query.
"""
control_subdomain = "switch"
def __init__(self, handler_mapping: dict[str, Sequence[ResponseHandler]]):
self._handler_mapping = handler_mapping
def handle(
self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
) -> str | None:
if len(args) != 1 or args[0] not in self._handler_mapping:
logging.error("Invalid %s query %s", self, qctx.qname)
qctx.response.set_rcode(dns.rcode.SERVFAIL)
return f"invalid query; exactly one of {list(self._handler_mapping.keys())} is expected in QNAME"
server.replace_response_handlers(*self._handler_mapping[args[0]])
return f"switched to handler set '{args[0]}'"