diff --git a/bin/tests/system/isctest/asyncserver.py b/bin/tests/system/isctest/asyncserver.py index dd784ef5b6..8e4ea245e5 100644 --- a/bin/tests/system/isctest/asyncserver.py +++ b/bin/tests/system/isctest/asyncserver.py @@ -266,6 +266,7 @@ class QueryContext: query: dns.message.Message response: dns.message.Message + socket: Peer peer: Peer protocol: DnsProtocol zone: Optional[dns.zone.Zone] = field(default=None, init=False) @@ -787,6 +788,108 @@ class DomainHandler(ResponseHandler): return False +class ForwarderHandler(ResponseHandler): + """ + A handler forwarding all received queries to another DNS server with an + optional delay and then relaying the responses back to the original client. + + Queries are currently always forwarded via UDP. + """ + + @property + @abc.abstractmethod + def target(self) -> str: + """ + The address of the DNS server to forward queries to. + """ + raise NotImplementedError + + @property + def port(self) -> int: + """ + The port of the DNS server to forward queries to. + + The default value of 0 causes the same port as the one used by this + server for listening to be used. + """ + return 0 + + @property + def delay(self) -> float: + """ + The number of seconds to wait before forwarding each query. + """ + return 0.0 + + def __str__(self) -> str: + return f"{self.__class__.__name__}(target: {self.target}:{self.port})" + + class ForwarderProtocol(asyncio.DatagramProtocol): + def __init__(self, query: bytes, response: asyncio.Future) -> None: + self._query = query + self._response = response + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + logging.debug("[OUT] %s", self._query.hex()) + cast(asyncio.DatagramTransport, transport).sendto(self._query) + + def datagram_received(self, data: bytes, _: Tuple[str, int]) -> None: + logging.debug("[IN] %s", data.hex()) + self._response.set_result(data) + + async def get_responses( + self, qctx: QueryContext + ) -> AsyncGenerator[ResponseAction, None]: + loop = asyncio.get_running_loop() + response = loop.create_future() + forwarding_target = f"{self.target}:{self.port or qctx.socket.port}" + + if self.delay > 0: + logging.info( + "Waiting %.1fs before forwarding %s query from %s to %s over UDP", + self.delay, + qctx.protocol.name, + qctx.peer, + forwarding_target, + ) + await asyncio.sleep(self.delay) + + logging.info( + "Forwarding %s query from %s to %s over UDP", + qctx.protocol.name, + qctx.peer, + forwarding_target, + ) + + transport, _ = await loop.create_datagram_endpoint( + lambda: self.ForwarderProtocol(qctx.query.to_wire(), response), + local_addr=(qctx.socket.host, 0), + remote_addr=(self.target, self.port or qctx.socket.port), + ) + + try: + await response + finally: + transport.close() + + logging.info( + "Relaying UDP response from %s to %s over %s", + forwarding_target, + qctx.peer, + qctx.protocol.name, + ) + + try: + message = _DnsMessageWithTsigDisabled.from_wire(response.result()) + yield DnsResponseSend(message, acknowledge_hand_rolled_response=True) + except dns.exception.DNSException: + logging.warning( + "Failed to parse response from %s as a DNS message, relaying it as raw bytes", + forwarding_target, + ) + yield BytesResponseSend(response.result()) + + @dataclass class _ZoneTreeNode: """ @@ -1072,8 +1175,10 @@ class AsyncDnsServer(AsyncServer): self, wire: bytes, addr: Tuple[str, int], transport: asyncio.DatagramTransport ) -> None: logging.debug("Received UDP message: %s", wire.hex()) + socket_info = transport.get_extra_info("sockname") + socket = Peer(socket_info[0], socket_info[1]) peer = Peer(addr[0], addr[1]) - responses = self._handle_query(wire, peer, DnsProtocol.UDP) + responses = self._handle_query(wire, socket, peer, DnsProtocol.UDP) async for response in responses: logging.debug("Sending UDP message: %s", response.hex()) transport.sendto(response, addr) @@ -1170,39 +1275,39 @@ class AsyncDnsServer(AsyncServer): async def _send_tcp_response( self, writer: asyncio.StreamWriter, peer: Peer, wire: bytes ) -> None: - responses = self._handle_query(wire, peer, DnsProtocol.TCP) + socket_info = writer.get_extra_info("sockname") + socket = Peer(socket_info[0], socket_info[1]) + responses = self._handle_query(wire, socket, peer, DnsProtocol.TCP) async for response in responses: logging.debug("Sending TCP response: %s", response.hex()) writer.write(response) await writer.drain() - def _log_query(self, qctx: QueryContext, peer: Peer, protocol: DnsProtocol) -> None: + def _log_query(self, qctx: QueryContext) -> None: logging.info( - "Received %s/%s/%s (ID=%d) query from %s (%s)", + "Received %s/%s/%s (ID=%d) query from %s on %s (%s)", qctx.qname.to_text(omit_final_dot=True), dns.rdataclass.to_text(qctx.qclass), dns.rdatatype.to_text(qctx.qtype), qctx.query.id, - peer, - protocol.name, + qctx.peer, + qctx.socket, + qctx.protocol.name, ) logging.debug( "\n".join([f"[IN] {l}" for l in [""] + str(qctx.query).splitlines()]) ) def _log_response( - self, - qctx: QueryContext, - response: Optional[Union[dns.message.Message, bytes]], - peer: Peer, - protocol: DnsProtocol, + self, qctx: QueryContext, response: Optional[Union[dns.message.Message, bytes]] ) -> None: if not response: logging.info( - "Not sending a response to query (ID=%d) from %s (%s)", + "Not sending a response to query (ID=%d) from %s on %s (%s)", qctx.query.id, - peer, - protocol.name, + qctx.peer, + qctx.socket, + qctx.protocol.name, ) return @@ -1217,7 +1322,7 @@ class AsyncDnsServer(AsyncServer): qtype = "-" logging.info( - "Sending %s/%s/%s (ID=%d) response (%d/%d/%d/%d) to a query (ID=%d) from %s (%s)", + "Sending %s/%s/%s (ID=%d) response (%d/%d/%d/%d) to a query (ID=%d) from %s on %s (%s)", qname, qclass, qtype, @@ -1227,8 +1332,9 @@ class AsyncDnsServer(AsyncServer): len(response.authority), len(response.additional), qctx.query.id, - peer, - protocol.name, + qctx.peer, + qctx.socket, + qctx.protocol.name, ) logging.debug( "\n".join([f"[OUT] {l}" for l in [""] + str(response).splitlines()]) @@ -1236,16 +1342,17 @@ class AsyncDnsServer(AsyncServer): return logging.info( - "Sending response (%d bytes) to a query (ID=%d) from %s (%s)", + "Sending response (%d bytes) to a query (ID=%d) from %s on %s (%s)", len(response), qctx.query.id, - peer, - protocol.name, + qctx.peer, + qctx.socket, + qctx.protocol.name, ) logging.debug("[OUT] %s", response.hex()) async def _handle_query( - self, wire: bytes, peer: Peer, protocol: DnsProtocol + self, wire: bytes, socket: Peer, peer: Peer, protocol: DnsProtocol ) -> AsyncGenerator[bytes, None]: """ Yield wire data to send as a response over the established transport. @@ -1256,11 +1363,11 @@ class AsyncDnsServer(AsyncServer): logging.error("Invalid query from %s (%s): %s", peer, wire.hex(), exc) return response_stub = _make_asyncserver_response(query) - qctx = QueryContext(query, response_stub, peer, protocol) - self._log_query(qctx, peer, protocol) + qctx = QueryContext(query, response_stub, socket, peer, protocol) + self._log_query(qctx) responses = self._prepare_responses(qctx) async for response in responses: - self._log_response(qctx, response, peer, protocol) + self._log_response(qctx, response) if response: if isinstance(response, dns.message.Message): response = response.to_wire(max_size=65535) diff --git a/bin/tests/system/pipelined/ans5/ans.py b/bin/tests/system/pipelined/ans5/ans.py index 51f10ba3c7..268687e9d7 100644 --- a/bin/tests/system/pipelined/ans5/ans.py +++ b/bin/tests/system/pipelined/ans5/ans.py @@ -1,211 +1,28 @@ -# Copyright (C) Internet Systems Consortium, Inc. ("ISC") -# -# SPDX-License-Identifier: MPL-2.0 -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, you can obtain one at https://mozilla.org/MPL/2.0/. -# -# See the COPYRIGHT file distributed with this work for additional -# information regarding copyright ownership. +""" +Copyright (C) Internet Systems Consortium, Inc. ("ISC") -############################################################################ -# -# This tool acts as a TCP/UDP proxy and delays all incoming packets by 500 -# milliseconds. -# -# We use it to check pipelining - a client sents 8 questions over a -# pipelined connection - that require asking a normal (examplea) and a -# slow-responding (exampleb) servers: -# a.examplea -# a.exampleb -# b.examplea -# b.exampleb -# c.examplea -# c.exampleb -# d.examplea -# d.exampleb -# -# If pipelining works properly the answers will be returned out of order -# with all answers from examplea returned first, and then all answers -# from exampleb. -# -############################################################################ +SPDX-License-Identifier: MPL-2.0 -from __future__ import print_function +This Source Code Form is subject to the terms of the Mozilla Public +License, v. 2.0. If a copy of the MPL was not distributed with this +file, you can obtain one at https://mozilla.org/MPL/2.0/. -import datetime -import os -import select -import signal -import socket -import sys -import time -import threading -import struct +See the COPYRIGHT file distributed with this work for additional +information regarding copyright ownership. +""" -DELAY = 0.5 -THREADS = [] +from isctest.asyncserver import AsyncDnsServer, ForwarderHandler -def log(msg): - print(datetime.datetime.now().strftime("%d-%b-%Y %H:%M:%S.%f ") + msg) +class ForwardToNs2(ForwarderHandler): + target = "10.53.0.2" + delay = 0.5 -def sigterm(*_): - log("SIGTERM received, shutting down") - for thread in THREADS: - thread.close() - thread.join() - os.remove("ans.pid") - sys.exit(0) - - -class TCPDelayer(threading.Thread): - """For a given TCP connection conn we open a connection to (ip, port), - and then we delay each incoming packet by DELAY by putting it in a - queue. - In the pipelined test TCP should not be used, but it's here for - completnes. - """ - - def __init__(self, conn, ip, port): - threading.Thread.__init__(self) - self.conn = conn - self.cconn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.cconn.connect((ip, port)) - self.queue = [] - self.running = True - - def close(self): - self.running = False - - def run(self): - while self.running: - curr_timeout = 0.5 - try: - curr_timeout = self.queue[0][0] - time.monotonic() - except StopIteration: - pass - if curr_timeout > 0: - if curr_timeout == 0: - curr_timeout = 0.5 - rfds, _, _ = select.select( - [self.conn, self.cconn], [], [], curr_timeout - ) - if self.conn in rfds: - data = self.conn.recv(65535) - if not data: - return - self.queue.append((time.monotonic() + DELAY, data)) - if self.cconn in rfds: - data = self.cconn.recv(65535) - if not data == 0: - return - self.conn.send(data) - try: - while self.queue[0][0] - time.monotonic() < 0: - _, data = self.queue.pop(0) - self.cconn.send(data) - except StopIteration: - pass - - -class UDPDelayer(threading.Thread): - """Every incoming UDP packet is put in a queue for DELAY time, then - it's sent to (ip, port). We remember the query id to send the - response we get to a proper source, responses are not delayed. - """ - - def __init__(self, usock, ip, port): - threading.Thread.__init__(self) - self.sock = usock - self.csock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.dst = (ip, port) - self.queue = [] - self.qid_mapping = {} - self.running = True - - def close(self): - self.running = False - - def run(self): - while self.running: - curr_timeout = 0.5 - if self.queue: - curr_timeout = self.queue[0][0] - time.monotonic() - if curr_timeout >= 0: - if curr_timeout == 0: - curr_timeout = 0.5 - rfds, _, _ = select.select( - [self.sock, self.csock], [], [], curr_timeout - ) - if self.sock in rfds: - data, addr = self.sock.recvfrom(65535) - if not data: - return - self.queue.append((time.monotonic() + DELAY, data)) - qid = struct.unpack(">H", data[:2])[0] - log("Received a query from %s, queryid %d" % (str(addr), qid)) - self.qid_mapping[qid] = addr - if self.csock in rfds: - data, addr = self.csock.recvfrom(65535) - if not data: - return - qid = struct.unpack(">H", data[:2])[0] - dst = self.qid_mapping.get(qid) - if dst is not None: - self.sock.sendto(data, dst) - log( - "Received a response from %s, queryid %d, sending to %s" - % (str(addr), qid, str(dst)) - ) - while self.queue and self.queue[0][0] - time.monotonic() < 0: - _, data = self.queue.pop(0) - qid = struct.unpack(">H", data[:2])[0] - log("Sending a query to %s, queryid %d" % (str(self.dst), qid)) - self.csock.sendto(data, self.dst) - - -def main(): - signal.signal(signal.SIGTERM, sigterm) - signal.signal(signal.SIGINT, sigterm) - - with open("ans.pid", "w") as pidfile: - print(os.getpid(), file=pidfile) - - listenip = "10.53.0.5" - serverip = "10.53.0.2" - - try: - port = int(os.environ["PORT"]) - except KeyError: - port = 5300 - - log("Listening on %s:%d" % (listenip, port)) - - usock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - usock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - usock.bind((listenip, port)) - thread = UDPDelayer(usock, serverip, port) - thread.start() - THREADS.append(thread) - - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind((listenip, port)) - sock.listen(1) - sock.settimeout(1) - - while True: - try: - clientsock, _ = sock.accept() - log("Accepted connection from %s" % clientsock) - thread = TCPDelayer(clientsock, serverip, port) - thread.start() - THREADS.append(thread) - except socket.timeout: - pass +def main() -> None: + server = AsyncDnsServer() + server.install_response_handlers(ForwardToNs2()) + server.run() if __name__ == "__main__":