Let queries with TSIG parse in isctest.asyncserver.AsyncDnsServer

Previously, upon receiving a query with TSIG, the server would log
an error and timeout. As there is no way to set up the keyring in the
class anyway (and I believe we don't need it), this commit lets such
queries parse but logs the fact that the query has TSIG.

However, there is a bug [1] in dnspython, which causes `make_response`
and `to_wire` to crash on messages constructed by `from_wire` with
`keyring=False`, so the hack with `message.__class__` is needed to work
around this.

This makes just enough changes for the tsig system test to work with
dnspython >= 2.0.0. On older version the server gives up.

[1] https://github.com/rthalley/dnspython/issues/1205

(cherry picked from commit 72ac1fe234)
This commit is contained in:
Štěpán Balážik 2025-06-23 16:43:56 +02:00
parent b7e7923daa
commit 58571d588f

View file

@ -28,6 +28,7 @@ from typing import (
import abc
import asyncio
import contextlib
import enum
import functools
import logging
@ -46,6 +47,8 @@ import dns.rcode
import dns.rdataclass
import dns.rdatatype
import dns.rrset
import dns.tsig
import dns.version
import dns.zone
try:
@ -517,6 +520,67 @@ class _ZoneTree:
return node.zone if node != self._root else None
class _DnsMessageWithTsigDisabled(dns.message.Message):
"""
A wrapper for `dns.message.Message` that works around a dnspython bug
causing exceptions to be raised when `make_response()` or `to_wire()` are
called for a message created using `dns.message.from_wire(keyring=False)`.
See https://github.com/rthalley/dnspython/issues/1205 for more details.
"""
class _DisableTsigHandling(contextlib.ContextDecorator):
def __init__(self, message: Optional[dns.message.Message] = None) -> None:
self.original_tsig_sign = dns.tsig.sign
self.original_tsig_validate = dns.tsig.validate
if message:
self.tsig = message.tsig
def __enter__(self) -> None:
"""
Override the `dns.tsig.sign` and `dns.tsig.validate` functions to prevent them
from failing on messages initialized with `dns.message.from_wire(keyring=False)`.
"""
def sign(*_: Any, **__: Any) -> Tuple[dns.rdata.Rdata, None]:
assert self.tsig
return self.tsig[0], None
def validate(*_: Any, **__: Any) -> None:
return None
dns.tsig.sign = sign
dns.tsig.validate = validate
def __exit__(self, *_: Any, **__: Any) -> None:
dns.tsig.sign = self.original_tsig_sign
dns.tsig.validate = self.original_tsig_validate
@classmethod
def from_wire(cls, wire: bytes) -> "_DnsMessageWithTsigDisabled":
with cls._DisableTsigHandling():
message = dns.message.from_wire(wire, keyring=False)
message.__class__ = _DnsMessageWithTsigDisabled
return cast(_DnsMessageWithTsigDisabled, message)
@property
def had_tsig(self) -> bool:
"""
Override the `had_tsig()` method to always return False, to prevent
`make_response()` from crashing.
"""
return False
def to_wire(self, *args: Any, **kwargs: Any) -> bytes:
"""
Override the `to_wire()` method to prevent it from trying to sign
the message with TSIG.
"""
with self._DisableTsigHandling(self):
return super().to_wire(*args, **kwargs)
class AsyncDnsServer(AsyncServer):
"""
DNS server which responds to queries based on zone data and/or custom
@ -533,12 +597,17 @@ class AsyncDnsServer(AsyncServer):
response from scratch, without using zone data at all.
"""
def __init__(self, acknowledge_manual_dname_handling: bool = False) -> None:
def __init__(
self,
acknowledge_manual_dname_handling: bool = False,
acknowledge_tsig_dnspython_hacks: bool = False,
) -> None:
super().__init__(self._handle_udp, self._handle_tcp, "ans.pid")
self._zone_tree: _ZoneTree = _ZoneTree()
self._response_handlers: List[ResponseHandler] = []
self._acknowledge_manual_dname_handling = acknowledge_manual_dname_handling
self._acknowledge_tsig_dnspython_hacks = acknowledge_tsig_dnspython_hacks
self._load_zones()
@ -778,6 +847,10 @@ class AsyncDnsServer(AsyncServer):
"""
try:
query = dns.message.from_wire(wire)
except dns.message.UnknownTSIGKey:
self._abort_if_on_dnspython_version_less_than_2_0_0()
self._abort_if_tsig_signed_query_received_unless_acknowledged()
query = _DnsMessageWithTsigDisabled.from_wire(wire)
except dns.exception.DNSException as exc:
logging.error("Invalid query from %s (%s): %s", peer, wire.hex(), exc)
return
@ -796,6 +869,26 @@ class AsyncDnsServer(AsyncServer):
response_length = struct.pack("!H", len(response))
yield response_length + response
def _abort_if_on_dnspython_version_less_than_2_0_0(self) -> None:
if dns.version.MAJOR < 2:
error = "Receiving TSIG signed queries requires dnspython >= 2.0.0; "
error += 'add `pytest.importorskip("dns", minversion="2.0.0")` '
error += "to the test module to skip this test."
raise RuntimeError(error)
def _abort_if_tsig_signed_query_received_unless_acknowledged(self) -> None:
if self._acknowledge_tsig_dnspython_hacks:
return
error = "TSIG-signed query received; "
error += "due to a bug in dnspython, this requires some hacking around; "
error += "you may experience unexpected behavior when dealing with TSIG; "
error += "TSIG validation is disabled, so any TSIG handling must be done "
error += "manually; pass `acknowledge_tsig_dnspython_hacks=True` to the "
error += "AsyncDnsServer constructor to acknowledge this and continue."
raise ValueError(error)
async def _prepare_responses(
self, qctx: QueryContext
) -> AsyncGenerator[Optional[Union[dns.message.Message, bytes]], None]: