Enable receiving chunked TCP DNS messages

A TCP DNS client may send its queries in chunks, causing
StreamReader.read() to return less data than previously declared by the
client as the DNS message length; even the two-octet DNS message length
itself may be split up into two single-octet transmissions.  Sending
data in chunks is valid client behavior that should not be treated as an
error.  Add a new helper method for reading TCP data in a loop, properly
distinguishing between chunked queries and client disconnections.  Use
the new method for reading all TCP data from clients.

(cherry picked from commit 68fe9a5df5)
This commit is contained in:
Michał Kępień 2025-03-18 16:28:18 +01:00
parent 3d80d9778b
commit 96766f3d29

View file

@ -570,8 +570,8 @@ class AsyncDnsServer(AsyncServer):
) -> Optional[int]:
logging.debug("Receiving TCP message length from %s...", peer)
wire_length_bytes = await reader.read(2)
if len(wire_length_bytes) < 2:
wire_length_bytes = await self._read_tcp_octets(reader, peer, 2)
if not wire_length_bytes:
return None
(wire_length,) = struct.unpack("!H", wire_length_bytes)
@ -583,14 +583,38 @@ class AsyncDnsServer(AsyncServer):
) -> Optional[bytes]:
logging.debug("Receiving TCP message (%d octets) from %s...", wire_length, peer)
wire = await reader.read(wire_length)
if len(wire) < wire_length:
wire = await self._read_tcp_octets(reader, peer, wire_length)
if not wire:
return None
logging.debug("Received complete TCP message from %s: %s", peer, wire.hex())
return wire
async def _read_tcp_octets(
self, reader: asyncio.StreamReader, peer: Peer, expected: int
) -> Optional[bytes]:
buffer = b""
while len(buffer) < expected:
chunk = await reader.read(expected - len(buffer))
if not chunk:
if buffer:
logging.debug(
"Received short TCP message (%d octets) from %s: %s",
len(buffer),
peer,
buffer.hex(),
)
else:
logging.debug("Received disconnect from %s", peer)
return None
logging.debug("Received %d TCP octets from %s", len(chunk), peer)
buffer += chunk
return buffer
async def _send_tcp_response(
self, writer: asyncio.StreamWriter, peer: Peer, wire: bytes
) -> None: