chg: test: asyncserver.py: TCP improvements

This branch started off as `michal/upforwd-asyncserver`.  It quickly
turned out that the critical `asyncserver.py` change that was needed for
the `upforwd` system test was for the server to be able to read multiple
TCP queries on a single connection.  As currently present in `main`,
`asyncserver.py` closes every client connection after servicing a single
query.  Retaining that behavior would cause the `upforwd` system test to
fail and, in general, capturing all data sent by a client seems more
useful in tests than just closing connections quickly.  `asyncserver.py`
can always be extended in the future (e.g. by adding a new
`ResponseAction` that the networking code would react to) to reinstate
the original behavior, if it turns out to be necessary.

While working on changing that particular `asyncserver.py` behavior, I
noticed a couple of other deficiencies in the TCP connection handling
code, so I started addressing them.  One thing led to another and before
I noticed, enough changes were applied to be worth doing a separate
merge request, particularly given that the actual rewrite of
`upforwd/ans4/ans.pl` using `asyncserver.py` is trivial once the
required changes to `asyncserver.py` itself are applied.

Merge branch 'michal/asyncserver-tcp-improvements' into 'main'

See merge request isc-projects/bind9!10276
This commit is contained in:
Michał Kępień 2025-03-18 15:30:35 +00:00
commit c6e5710846

View file

@ -224,6 +224,20 @@ class DnsProtocol(enum.Enum):
TCP = enum.auto()
@dataclass(frozen=True)
class Peer:
"""
Pretty-printed connection endpoint.
"""
host: str
port: int
def __str__(self) -> str:
host = f"[{self.host}]" if ":" in self.host else self.host
return f"{host}:{self.port}"
@dataclass
class QueryContext:
"""
@ -232,7 +246,7 @@ class QueryContext:
query: dns.message.Message
response: dns.message.Message
peer: Tuple[str, int]
peer: Peer
protocol: DnsProtocol
zone: Optional[dns.zone.Zone] = None
soa: Optional[dns.rrset.RRset] = None
@ -513,56 +527,110 @@ class AsyncDnsServer(AsyncServer):
self._zone_tree.add(zone)
async def _handle_udp(
self, wire: bytes, peer: Tuple[str, int], transport: asyncio.DatagramTransport
self, wire: bytes, addr: Tuple[str, int], transport: asyncio.DatagramTransport
) -> None:
logging.debug("Received UDP message: %s", wire.hex())
peer = Peer(addr[0], addr[1])
responses = self._handle_query(wire, peer, DnsProtocol.UDP)
async for response in responses:
transport.sendto(response, peer)
transport.sendto(response, addr)
async def _handle_tcp(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
wire_length_bytes = await reader.read(2)
(wire_length,) = struct.unpack("!H", wire_length_bytes)
logging.debug("Receiving TCP message (%d octets)...", wire_length)
peer_info = writer.get_extra_info("peername")
peer = Peer(peer_info[0], peer_info[1])
logging.debug("Accepted TCP connection from %s", peer)
wire = await reader.read(wire_length)
full_message = wire_length_bytes + wire
logging.debug("Received complete TCP message: %s", full_message.hex())
peer = writer.get_extra_info("peername")
responses = self._handle_query(wire, peer, DnsProtocol.TCP)
async for response in responses:
writer.write(response)
while True:
try:
await writer.drain()
wire = await self._read_tcp_query(reader, peer)
if not wire:
break
await self._send_tcp_response(writer, peer, wire)
except ConnectionResetError:
logging.error(
"TCP connection from %s reset by peer", self._format_peer(peer)
)
logging.error("TCP connection from %s reset by peer", peer)
return
logging.debug("Closing TCP connection from %s", peer)
writer.close()
await writer.wait_closed()
def _format_peer(self, peer: Tuple[str, int]) -> str:
host = peer[0]
port = peer[1]
if "::" in host:
host = f"[{host}]"
return f"{host}:{port}"
async def _read_tcp_query(
self, reader: asyncio.StreamReader, peer: Peer
) -> Optional[bytes]:
wire_length = await self._read_tcp_query_wire_length(reader, peer)
if not wire_length:
return None
def _log_query(
self, qctx: QueryContext, peer: Tuple[str, int], protocol: DnsProtocol
return await self._read_tcp_query_wire(reader, peer, wire_length)
async def _read_tcp_query_wire_length(
self, reader: asyncio.StreamReader, peer: Peer
) -> Optional[int]:
logging.debug("Receiving TCP message length from %s...", peer)
wire_length_bytes = await self._read_tcp_octets(reader, peer, 2)
if not wire_length_bytes:
return None
(wire_length,) = struct.unpack("!H", wire_length_bytes)
return wire_length
async def _read_tcp_query_wire(
self, reader: asyncio.StreamReader, peer: Peer, wire_length: int
) -> Optional[bytes]:
logging.debug("Receiving TCP message (%d octets) from %s...", wire_length, peer)
wire = await self._read_tcp_octets(reader, peer, wire_length)
if not wire:
return None
logging.debug("Received complete TCP message from %s: %s", peer, wire.hex())
return wire
async def _read_tcp_octets(
self, reader: asyncio.StreamReader, peer: Peer, expected: int
) -> Optional[bytes]:
buffer = b""
while len(buffer) < expected:
chunk = await reader.read(expected - len(buffer))
if not chunk:
if buffer:
logging.debug(
"Received short TCP message (%d octets) from %s: %s",
len(buffer),
peer,
buffer.hex(),
)
else:
logging.debug("Received disconnect from %s", peer)
return None
logging.debug("Received %d TCP octets from %s", len(chunk), peer)
buffer += chunk
return buffer
async def _send_tcp_response(
self, writer: asyncio.StreamWriter, peer: Peer, wire: bytes
) -> None:
responses = self._handle_query(wire, peer, DnsProtocol.TCP)
async for response in responses:
writer.write(response)
await writer.drain()
def _log_query(self, qctx: QueryContext, peer: Peer, protocol: DnsProtocol) -> None:
logging.info(
"Received %s/%s/%s (ID=%d) query from %s (%s)",
qctx.qname.to_text(omit_final_dot=True),
dns.rdataclass.to_text(qctx.qclass),
dns.rdatatype.to_text(qctx.qtype),
qctx.query.id,
self._format_peer(peer),
peer,
protocol.name,
)
logging.debug(
@ -573,14 +641,14 @@ class AsyncDnsServer(AsyncServer):
self,
qctx: QueryContext,
response: Optional[Union[dns.message.Message, bytes]],
peer: Tuple[str, int],
peer: Peer,
protocol: DnsProtocol,
) -> None:
if not response:
logging.info(
"Not sending a response to query (ID=%d) from %s (%s)",
qctx.query.id,
self._format_peer(peer),
peer,
protocol.name,
)
return
@ -606,7 +674,7 @@ class AsyncDnsServer(AsyncServer):
len(response.authority),
len(response.additional),
qctx.query.id,
self._format_peer(peer),
peer,
protocol.name,
)
logging.debug(
@ -618,13 +686,13 @@ class AsyncDnsServer(AsyncServer):
"Sending response (%d bytes) to a query (ID=%d) from %s (%s)",
len(response),
qctx.query.id,
self._format_peer(peer),
peer,
protocol.name,
)
logging.debug("[OUT] %s", response.hex())
async def _handle_query(
self, wire: bytes, peer: Tuple[str, int], protocol: DnsProtocol
self, wire: bytes, peer: Peer, protocol: DnsProtocol
) -> AsyncGenerator[bytes, None]:
"""
Yield wire data to send as a response over the established transport.