Implement a response handler that forwards queries

Add a new response handler, ForwarderHandler, which enables forwarding
all queries to another DNS server.  To simplify implementation, always
forward queries to the target server via UDP, even if they are
originally received using a different transport protocol.

(cherry picked from commit 10a2fc7f1f)
This commit is contained in:
Michał Kępień 2026-02-13 14:27:10 +01:00 committed by Michał Kępień (GitLab job 6866004)
parent f773a18f40
commit 5a0e1de2e5

View file

@ -788,6 +788,108 @@ class DomainHandler(ResponseHandler):
return False
class ForwarderHandler(ResponseHandler):
"""
A handler forwarding all received queries to another DNS server with an
optional delay and then relaying the responses back to the original client.
Queries are currently always forwarded via UDP.
"""
@property
@abc.abstractmethod
def target(self) -> str:
"""
The address of the DNS server to forward queries to.
"""
raise NotImplementedError
@property
def port(self) -> int:
"""
The port of the DNS server to forward queries to.
The default value of 0 causes the same port as the one used by this
server for listening to be used.
"""
return 0
@property
def delay(self) -> float:
"""
The number of seconds to wait before forwarding each query.
"""
return 0.0
def __str__(self) -> str:
return f"{self.__class__.__name__}(target: {self.target}:{self.port})"
class ForwarderProtocol(asyncio.DatagramProtocol):
def __init__(self, query: bytes, response: asyncio.Future) -> None:
self._query = query
self._response = response
def connection_made(self, transport: asyncio.BaseTransport) -> None:
logging.debug("[OUT] %s", self._query.hex())
cast(asyncio.DatagramTransport, transport).sendto(self._query)
def datagram_received(self, data: bytes, _: Tuple[str, int]) -> None:
logging.debug("[IN] %s", data.hex())
self._response.set_result(data)
async def get_responses(
self, qctx: QueryContext
) -> AsyncGenerator[ResponseAction, None]:
loop = asyncio.get_running_loop()
response = loop.create_future()
forwarding_target = f"{self.target}:{self.port or qctx.socket.port}"
if self.delay > 0:
logging.info(
"Waiting %.1fs before forwarding %s query from %s to %s over UDP",
self.delay,
qctx.protocol.name,
qctx.peer,
forwarding_target,
)
await asyncio.sleep(self.delay)
logging.info(
"Forwarding %s query from %s to %s over UDP",
qctx.protocol.name,
qctx.peer,
forwarding_target,
)
transport, _ = await loop.create_datagram_endpoint(
lambda: self.ForwarderProtocol(qctx.query.to_wire(), response),
local_addr=(qctx.socket.host, 0),
remote_addr=(self.target, self.port or qctx.socket.port),
)
try:
await response
finally:
transport.close()
logging.info(
"Relaying UDP response from %s to %s over %s",
forwarding_target,
qctx.peer,
qctx.protocol.name,
)
try:
message = _DnsMessageWithTsigDisabled.from_wire(response.result())
yield DnsResponseSend(message, acknowledge_hand_rolled_response=True)
except dns.exception.DNSException:
logging.warning(
"Failed to parse response from %s as a DNS message, relaying it as raw bytes",
forwarding_target,
)
yield BytesResponseSend(response.result())
@dataclass
class _ZoneTreeNode:
"""