Store server socket information in QueryContext

Extend the QueryContext class with a field holding the <address, port>
tuple for the socket on which a given query was received.  This will
enable query handlers to act upon that information in arbitrary ways.

(cherry picked from commit 94a4793596)
This commit is contained in:
Michał Kępień 2026-02-13 14:27:10 +01:00 committed by Michał Kępień (GitLab job 6866004)
parent 66c58ce793
commit 442285dce3

View file

@ -266,6 +266,7 @@ class QueryContext:
query: dns.message.Message
response: dns.message.Message
socket: Peer
peer: Peer
protocol: DnsProtocol
zone: Optional[dns.zone.Zone] = field(default=None, init=False)
@ -1072,8 +1073,10 @@ class AsyncDnsServer(AsyncServer):
self, wire: bytes, addr: Tuple[str, int], transport: asyncio.DatagramTransport
) -> None:
logging.debug("Received UDP message: %s", wire.hex())
socket_info = transport.get_extra_info("sockname")
socket = Peer(socket_info[0], socket_info[1])
peer = Peer(addr[0], addr[1])
responses = self._handle_query(wire, peer, DnsProtocol.UDP)
responses = self._handle_query(wire, socket, peer, DnsProtocol.UDP)
async for response in responses:
logging.debug("Sending UDP message: %s", response.hex())
transport.sendto(response, addr)
@ -1170,7 +1173,9 @@ class AsyncDnsServer(AsyncServer):
async def _send_tcp_response(
self, writer: asyncio.StreamWriter, peer: Peer, wire: bytes
) -> None:
responses = self._handle_query(wire, peer, DnsProtocol.TCP)
socket_info = writer.get_extra_info("sockname")
socket = Peer(socket_info[0], socket_info[1])
responses = self._handle_query(wire, socket, peer, DnsProtocol.TCP)
async for response in responses:
logging.debug("Sending TCP response: %s", response.hex())
writer.write(response)
@ -1245,7 +1250,7 @@ class AsyncDnsServer(AsyncServer):
logging.debug("[OUT] %s", response.hex())
async def _handle_query(
self, wire: bytes, peer: Peer, protocol: DnsProtocol
self, wire: bytes, socket: Peer, peer: Peer, protocol: DnsProtocol
) -> AsyncGenerator[bytes, None]:
"""
Yield wire data to send as a response over the established transport.
@ -1256,7 +1261,7 @@ class AsyncDnsServer(AsyncServer):
logging.error("Invalid query from %s (%s): %s", peer, wire.hex(), exc)
return
response_stub = _make_asyncserver_response(query)
qctx = QueryContext(query, response_stub, peer, protocol)
qctx = QueryContext(query, response_stub, socket, peer, protocol)
self._log_query(qctx, peer, protocol)
responses = self._prepare_responses(qctx)
async for response in responses: