mirror of
https://github.com/isc-projects/bind9.git
synced 2026-06-03 22:08:25 -04:00
chg: test: asyncserver.py: TCP improvements
This branch started off as `michal/upforwd-asyncserver`. It quickly turned out that the critical `asyncserver.py` change that was needed for the `upforwd` system test was for the server to be able to read multiple TCP queries on a single connection. As currently present in `main`, `asyncserver.py` closes every client connection after servicing a single query. Retaining that behavior would cause the `upforwd` system test to fail and, in general, capturing all data sent by a client seems more useful in tests than just closing connections quickly. `asyncserver.py` can always be extended in the future (e.g. by adding a new `ResponseAction` that the networking code would react to) to reinstate the original behavior, if it turns out to be necessary. While working on changing that particular `asyncserver.py` behavior, I noticed a couple of other deficiencies in the TCP connection handling code, so I started addressing them. One thing led to another and before I noticed, enough changes were applied to be worth doing a separate merge request, particularly given that the actual rewrite of `upforwd/ans4/ans.pl` using `asyncserver.py` is trivial once the required changes to `asyncserver.py` itself are applied. Merge branch 'michal/asyncserver-tcp-improvements' into 'main' See merge request isc-projects/bind9!10276
This commit is contained in:
commit
c6e5710846
1 changed files with 100 additions and 32 deletions
|
|
@ -224,6 +224,20 @@ class DnsProtocol(enum.Enum):
|
|||
TCP = enum.auto()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Peer:
|
||||
"""
|
||||
Pretty-printed connection endpoint.
|
||||
"""
|
||||
|
||||
host: str
|
||||
port: int
|
||||
|
||||
def __str__(self) -> str:
|
||||
host = f"[{self.host}]" if ":" in self.host else self.host
|
||||
return f"{host}:{self.port}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryContext:
|
||||
"""
|
||||
|
|
@ -232,7 +246,7 @@ class QueryContext:
|
|||
|
||||
query: dns.message.Message
|
||||
response: dns.message.Message
|
||||
peer: Tuple[str, int]
|
||||
peer: Peer
|
||||
protocol: DnsProtocol
|
||||
zone: Optional[dns.zone.Zone] = None
|
||||
soa: Optional[dns.rrset.RRset] = None
|
||||
|
|
@ -513,56 +527,110 @@ class AsyncDnsServer(AsyncServer):
|
|||
self._zone_tree.add(zone)
|
||||
|
||||
async def _handle_udp(
|
||||
self, wire: bytes, peer: Tuple[str, int], transport: asyncio.DatagramTransport
|
||||
self, wire: bytes, addr: Tuple[str, int], transport: asyncio.DatagramTransport
|
||||
) -> None:
|
||||
logging.debug("Received UDP message: %s", wire.hex())
|
||||
peer = Peer(addr[0], addr[1])
|
||||
responses = self._handle_query(wire, peer, DnsProtocol.UDP)
|
||||
async for response in responses:
|
||||
transport.sendto(response, peer)
|
||||
transport.sendto(response, addr)
|
||||
|
||||
async def _handle_tcp(
|
||||
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
wire_length_bytes = await reader.read(2)
|
||||
(wire_length,) = struct.unpack("!H", wire_length_bytes)
|
||||
logging.debug("Receiving TCP message (%d octets)...", wire_length)
|
||||
peer_info = writer.get_extra_info("peername")
|
||||
peer = Peer(peer_info[0], peer_info[1])
|
||||
logging.debug("Accepted TCP connection from %s", peer)
|
||||
|
||||
wire = await reader.read(wire_length)
|
||||
full_message = wire_length_bytes + wire
|
||||
logging.debug("Received complete TCP message: %s", full_message.hex())
|
||||
|
||||
peer = writer.get_extra_info("peername")
|
||||
responses = self._handle_query(wire, peer, DnsProtocol.TCP)
|
||||
async for response in responses:
|
||||
writer.write(response)
|
||||
while True:
|
||||
try:
|
||||
await writer.drain()
|
||||
wire = await self._read_tcp_query(reader, peer)
|
||||
if not wire:
|
||||
break
|
||||
await self._send_tcp_response(writer, peer, wire)
|
||||
except ConnectionResetError:
|
||||
logging.error(
|
||||
"TCP connection from %s reset by peer", self._format_peer(peer)
|
||||
)
|
||||
logging.error("TCP connection from %s reset by peer", peer)
|
||||
return
|
||||
|
||||
logging.debug("Closing TCP connection from %s", peer)
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
def _format_peer(self, peer: Tuple[str, int]) -> str:
|
||||
host = peer[0]
|
||||
port = peer[1]
|
||||
if "::" in host:
|
||||
host = f"[{host}]"
|
||||
return f"{host}:{port}"
|
||||
async def _read_tcp_query(
|
||||
self, reader: asyncio.StreamReader, peer: Peer
|
||||
) -> Optional[bytes]:
|
||||
wire_length = await self._read_tcp_query_wire_length(reader, peer)
|
||||
if not wire_length:
|
||||
return None
|
||||
|
||||
def _log_query(
|
||||
self, qctx: QueryContext, peer: Tuple[str, int], protocol: DnsProtocol
|
||||
return await self._read_tcp_query_wire(reader, peer, wire_length)
|
||||
|
||||
async def _read_tcp_query_wire_length(
|
||||
self, reader: asyncio.StreamReader, peer: Peer
|
||||
) -> Optional[int]:
|
||||
logging.debug("Receiving TCP message length from %s...", peer)
|
||||
|
||||
wire_length_bytes = await self._read_tcp_octets(reader, peer, 2)
|
||||
if not wire_length_bytes:
|
||||
return None
|
||||
|
||||
(wire_length,) = struct.unpack("!H", wire_length_bytes)
|
||||
|
||||
return wire_length
|
||||
|
||||
async def _read_tcp_query_wire(
|
||||
self, reader: asyncio.StreamReader, peer: Peer, wire_length: int
|
||||
) -> Optional[bytes]:
|
||||
logging.debug("Receiving TCP message (%d octets) from %s...", wire_length, peer)
|
||||
|
||||
wire = await self._read_tcp_octets(reader, peer, wire_length)
|
||||
if not wire:
|
||||
return None
|
||||
|
||||
logging.debug("Received complete TCP message from %s: %s", peer, wire.hex())
|
||||
|
||||
return wire
|
||||
|
||||
async def _read_tcp_octets(
|
||||
self, reader: asyncio.StreamReader, peer: Peer, expected: int
|
||||
) -> Optional[bytes]:
|
||||
buffer = b""
|
||||
|
||||
while len(buffer) < expected:
|
||||
chunk = await reader.read(expected - len(buffer))
|
||||
if not chunk:
|
||||
if buffer:
|
||||
logging.debug(
|
||||
"Received short TCP message (%d octets) from %s: %s",
|
||||
len(buffer),
|
||||
peer,
|
||||
buffer.hex(),
|
||||
)
|
||||
else:
|
||||
logging.debug("Received disconnect from %s", peer)
|
||||
return None
|
||||
|
||||
logging.debug("Received %d TCP octets from %s", len(chunk), peer)
|
||||
buffer += chunk
|
||||
|
||||
return buffer
|
||||
|
||||
async def _send_tcp_response(
|
||||
self, writer: asyncio.StreamWriter, peer: Peer, wire: bytes
|
||||
) -> None:
|
||||
responses = self._handle_query(wire, peer, DnsProtocol.TCP)
|
||||
async for response in responses:
|
||||
writer.write(response)
|
||||
await writer.drain()
|
||||
|
||||
def _log_query(self, qctx: QueryContext, peer: Peer, protocol: DnsProtocol) -> None:
|
||||
logging.info(
|
||||
"Received %s/%s/%s (ID=%d) query from %s (%s)",
|
||||
qctx.qname.to_text(omit_final_dot=True),
|
||||
dns.rdataclass.to_text(qctx.qclass),
|
||||
dns.rdatatype.to_text(qctx.qtype),
|
||||
qctx.query.id,
|
||||
self._format_peer(peer),
|
||||
peer,
|
||||
protocol.name,
|
||||
)
|
||||
logging.debug(
|
||||
|
|
@ -573,14 +641,14 @@ class AsyncDnsServer(AsyncServer):
|
|||
self,
|
||||
qctx: QueryContext,
|
||||
response: Optional[Union[dns.message.Message, bytes]],
|
||||
peer: Tuple[str, int],
|
||||
peer: Peer,
|
||||
protocol: DnsProtocol,
|
||||
) -> None:
|
||||
if not response:
|
||||
logging.info(
|
||||
"Not sending a response to query (ID=%d) from %s (%s)",
|
||||
qctx.query.id,
|
||||
self._format_peer(peer),
|
||||
peer,
|
||||
protocol.name,
|
||||
)
|
||||
return
|
||||
|
|
@ -606,7 +674,7 @@ class AsyncDnsServer(AsyncServer):
|
|||
len(response.authority),
|
||||
len(response.additional),
|
||||
qctx.query.id,
|
||||
self._format_peer(peer),
|
||||
peer,
|
||||
protocol.name,
|
||||
)
|
||||
logging.debug(
|
||||
|
|
@ -618,13 +686,13 @@ class AsyncDnsServer(AsyncServer):
|
|||
"Sending response (%d bytes) to a query (ID=%d) from %s (%s)",
|
||||
len(response),
|
||||
qctx.query.id,
|
||||
self._format_peer(peer),
|
||||
peer,
|
||||
protocol.name,
|
||||
)
|
||||
logging.debug("[OUT] %s", response.hex())
|
||||
|
||||
async def _handle_query(
|
||||
self, wire: bytes, peer: Tuple[str, int], protocol: DnsProtocol
|
||||
self, wire: bytes, peer: Peer, protocol: DnsProtocol
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
"""
|
||||
Yield wire data to send as a response over the established transport.
|
||||
|
|
|
|||
Loading…
Reference in a new issue