diff --git a/bin/tests/system/isctest/asyncserver.py b/bin/tests/system/isctest/asyncserver.py index ab508b404a..b4270fa899 100644 --- a/bin/tests/system/isctest/asyncserver.py +++ b/bin/tests/system/isctest/asyncserver.py @@ -224,6 +224,20 @@ class DnsProtocol(enum.Enum): TCP = enum.auto() +@dataclass(frozen=True) +class Peer: + """ + Pretty-printed connection endpoint. + """ + + host: str + port: int + + def __str__(self) -> str: + host = f"[{self.host}]" if ":" in self.host else self.host + return f"{host}:{self.port}" + + @dataclass class QueryContext: """ @@ -232,7 +246,7 @@ class QueryContext: query: dns.message.Message response: dns.message.Message - peer: Tuple[str, int] + peer: Peer protocol: DnsProtocol zone: Optional[dns.zone.Zone] = None soa: Optional[dns.rrset.RRset] = None @@ -513,56 +527,110 @@ class AsyncDnsServer(AsyncServer): self._zone_tree.add(zone) async def _handle_udp( - self, wire: bytes, peer: Tuple[str, int], transport: asyncio.DatagramTransport + self, wire: bytes, addr: Tuple[str, int], transport: asyncio.DatagramTransport ) -> None: logging.debug("Received UDP message: %s", wire.hex()) + peer = Peer(addr[0], addr[1]) responses = self._handle_query(wire, peer, DnsProtocol.UDP) async for response in responses: - transport.sendto(response, peer) + transport.sendto(response, addr) async def _handle_tcp( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: - wire_length_bytes = await reader.read(2) - (wire_length,) = struct.unpack("!H", wire_length_bytes) - logging.debug("Receiving TCP message (%d octets)...", wire_length) + peer_info = writer.get_extra_info("peername") + peer = Peer(peer_info[0], peer_info[1]) + logging.debug("Accepted TCP connection from %s", peer) - wire = await reader.read(wire_length) - full_message = wire_length_bytes + wire - logging.debug("Received complete TCP message: %s", full_message.hex()) - - peer = writer.get_extra_info("peername") - responses = self._handle_query(wire, peer, DnsProtocol.TCP) - async for response in responses: - writer.write(response) + while True: try: - await writer.drain() + wire = await self._read_tcp_query(reader, peer) + if not wire: + break + await self._send_tcp_response(writer, peer, wire) except ConnectionResetError: - logging.error( - "TCP connection from %s reset by peer", self._format_peer(peer) - ) + logging.error("TCP connection from %s reset by peer", peer) return + logging.debug("Closing TCP connection from %s", peer) writer.close() await writer.wait_closed() - def _format_peer(self, peer: Tuple[str, int]) -> str: - host = peer[0] - port = peer[1] - if "::" in host: - host = f"[{host}]" - return f"{host}:{port}" + async def _read_tcp_query( + self, reader: asyncio.StreamReader, peer: Peer + ) -> Optional[bytes]: + wire_length = await self._read_tcp_query_wire_length(reader, peer) + if not wire_length: + return None - def _log_query( - self, qctx: QueryContext, peer: Tuple[str, int], protocol: DnsProtocol + return await self._read_tcp_query_wire(reader, peer, wire_length) + + async def _read_tcp_query_wire_length( + self, reader: asyncio.StreamReader, peer: Peer + ) -> Optional[int]: + logging.debug("Receiving TCP message length from %s...", peer) + + wire_length_bytes = await self._read_tcp_octets(reader, peer, 2) + if not wire_length_bytes: + return None + + (wire_length,) = struct.unpack("!H", wire_length_bytes) + + return wire_length + + async def _read_tcp_query_wire( + self, reader: asyncio.StreamReader, peer: Peer, wire_length: int + ) -> Optional[bytes]: + logging.debug("Receiving TCP message (%d octets) from %s...", wire_length, peer) + + wire = await self._read_tcp_octets(reader, peer, wire_length) + if not wire: + return None + + logging.debug("Received complete TCP message from %s: %s", peer, wire.hex()) + + return wire + + async def _read_tcp_octets( + self, reader: asyncio.StreamReader, peer: Peer, expected: int + ) -> Optional[bytes]: + buffer = b"" + + while len(buffer) < expected: + chunk = await reader.read(expected - len(buffer)) + if not chunk: + if buffer: + logging.debug( + "Received short TCP message (%d octets) from %s: %s", + len(buffer), + peer, + buffer.hex(), + ) + else: + logging.debug("Received disconnect from %s", peer) + return None + + logging.debug("Received %d TCP octets from %s", len(chunk), peer) + buffer += chunk + + return buffer + + async def _send_tcp_response( + self, writer: asyncio.StreamWriter, peer: Peer, wire: bytes ) -> None: + responses = self._handle_query(wire, peer, DnsProtocol.TCP) + async for response in responses: + writer.write(response) + await writer.drain() + + def _log_query(self, qctx: QueryContext, peer: Peer, protocol: DnsProtocol) -> None: logging.info( "Received %s/%s/%s (ID=%d) query from %s (%s)", qctx.qname.to_text(omit_final_dot=True), dns.rdataclass.to_text(qctx.qclass), dns.rdatatype.to_text(qctx.qtype), qctx.query.id, - self._format_peer(peer), + peer, protocol.name, ) logging.debug( @@ -573,14 +641,14 @@ class AsyncDnsServer(AsyncServer): self, qctx: QueryContext, response: Optional[Union[dns.message.Message, bytes]], - peer: Tuple[str, int], + peer: Peer, protocol: DnsProtocol, ) -> None: if not response: logging.info( "Not sending a response to query (ID=%d) from %s (%s)", qctx.query.id, - self._format_peer(peer), + peer, protocol.name, ) return @@ -606,7 +674,7 @@ class AsyncDnsServer(AsyncServer): len(response.authority), len(response.additional), qctx.query.id, - self._format_peer(peer), + peer, protocol.name, ) logging.debug( @@ -618,13 +686,13 @@ class AsyncDnsServer(AsyncServer): "Sending response (%d bytes) to a query (ID=%d) from %s (%s)", len(response), qctx.query.id, - self._format_peer(peer), + peer, protocol.name, ) logging.debug("[OUT] %s", response.hex()) async def _handle_query( - self, wire: bytes, peer: Tuple[str, int], protocol: DnsProtocol + self, wire: bytes, peer: Peer, protocol: DnsProtocol ) -> AsyncGenerator[bytes, None]: """ Yield wire data to send as a response over the established transport.