bind9/bin/tests/system/isctest/query.py
Evan Hunt a2d74e7356
Make the RD flag optional in isctest.query()
Add an 'rd' parameter (default True) to isctest.query.create() so
that non-recursive queries can be sent with rd=False.

(cherry picked from commit 12e5113100)
2026-05-07 13:09:18 +02:00

184 lines
5.1 KiB
Python

# Copyright (C) Internet Systems Consortium, Inc. ("ISC")
#
# SPDX-License-Identifier: MPL-2.0
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, you can obtain one at https://mozilla.org/MPL/2.0/.
#
# See the COPYRIGHT file distributed with this work for additional
# information regarding copyright ownership.
from collections.abc import Callable
from typing import Any
import os
import time
import dns.exception
import dns.flags
import dns.message
import dns.name
import dns.query
import dns.rcode
import dns.rdataclass
import dns.rdatatype
import isctest.log
import isctest.run
# Default per-attempt timeout that generic_query() hands to the dnspython
# query function (dns.query.udp/tcp/tls) on every attempt.
QUERY_TIMEOUT = 10
def generic_query(
    query_func: Callable[..., Any],
    message: dns.message.Message,
    ip: str,
    port: int | None = None,
    source: str | None = None,
    timeout: int = QUERY_TIMEOUT,
    attempts: int = 10,
    expected_rcode: dns.rcode.Rcode | None = None,
    verify: bool = False,
    log_query: bool = True,
    log_response: bool = True,
) -> Any:
    """
    Send 'message' to 'ip' with 'query_func' (a dns.query transport
    function), retrying until an acceptable response is received.

    'port' defaults to $TLSPORT for TLS queries and to $PORT otherwise.
    'source' is the optional source address, passed to 'query_func'.
    'timeout' is passed to 'query_func' on every attempt.
    'attempts' is the maximum number of times the query is sent.
    'expected_rcode', when set, makes the function keep retrying until a
    response carrying that rcode arrives; when None, any response is
    accepted.
    'verify' is passed to 'query_func' only for TLS queries.
    'log_query' enables logging of the query message (at most once);
    'log_response' enables logging of every response received.

    Returns the first acceptable response. Timeouts and refused
    connections are retried; any other exception propagates. Raises
    dns.exception.Timeout when all attempts are used up without an
    acceptable response.
    """

    def log_querymsg(exception: Exception | None = None) -> None:
        """
        Helper for logging query message. Call this *after* query_func() has
        been called, as it may modify the message, e.g. with a TSIG.
        If an exception is provided, it will be logged as well.
        """
        nonlocal log_query
        if log_query:
            isctest.log.debug(
                f"isc.query.{query_func.__name__}(): query\n{message.to_text()}"
            )
            log_query = False  # only log query once
        if exception:
            isctest.log.debug(
                f"isc.query.{query_func.__name__}(): the '{exception}' exception raised"
            )

    if port is None:
        if query_func.__name__ == "tls":
            port = int(os.environ["TLSPORT"])
        else:
            port = int(os.environ["PORT"])
    query_args = {
        "q": message,
        "where": ip,
        "timeout": timeout,
        "port": port,
        "source": source,
    }
    if query_func.__name__ == "tls":
        query_args["verify"] = verify

    # Most recent response received on any attempt; used only for the
    # final "last rcode" diagnostic.
    last_res = None
    for attempt in range(attempts):
        log_msg = (
            f"isc.query.{query_func.__name__}(): ip={ip}, port={port}, source={source}, "
            f"timeout={timeout}, attempts left={attempts-attempt}"
        )
        isctest.log.debug(log_msg)
        # Reset per attempt so that a timed-out attempt does not re-log and
        # re-check the previous attempt's response as if it were new.
        res = None
        exc = None
        try:
            res = query_func(**query_args)
        except (dns.exception.Timeout, ConnectionRefusedError) as e:
            exc = e
        finally:
            log_querymsg(exc)
        if res is not None:
            last_res = res
            if log_response:
                isctest.log.debug(
                    f"isc.query.{query_func.__name__}(): response\n{res.to_text()}"
                )
            if expected_rcode is None or res.rcode() == expected_rcode:
                return res
        if attempt < attempts - 1:
            # Back off before retrying; pointless after the final attempt.
            time.sleep(1)
    if expected_rcode is not None:
        last_rcode = (
            dns.rcode.to_text(last_res.rcode()) if last_res is not None else None
        )
        isctest.log.debug(
            f"isc.query.{query_func.__name__}(): expected rcode={dns.rcode.to_text(expected_rcode)}, last rcode={last_rcode}"
        )
    raise dns.exception.Timeout
def udp(*args, **kwargs) -> Any:
    """Send a DNS query over UDP; every argument is forwarded to generic_query()."""
    transport = dns.query.udp
    return generic_query(transport, *args, **kwargs)
def tcp(*args, **kwargs) -> Any:
    """Send a DNS query over TCP; every argument is forwarded to generic_query()."""
    transport = dns.query.tcp
    return generic_query(transport, *args, **kwargs)
def tls(*args, **kwargs) -> Any:
    """
    Send a DNS query over TLS; every argument is forwarded to generic_query().

    generic_query() always passes 'verify' to dns.query.tls(). Per the
    RuntimeError message below, that parameter needs dnspython 2.5.0 or
    newer; on older versions the call fails with a TypeError. That TypeError
    is converted into a RuntimeError naming the required version, but only
    when the installed dns.query.tls() really lacks 'verify'. Any other
    TypeError, e.g. from a bad keyword argument, is re-raised unchanged
    instead of being misreported as a version problem.

    Raises RuntimeError when the installed dnspython is too old.
    """
    try:
        return generic_query(dns.query.tls, *args, **kwargs)
    except TypeError as e:
        import inspect  # only needed on this error path

        if "verify" in inspect.signature(dns.query.tls).parameters:
            # dnspython is new enough; this TypeError has another cause.
            raise
        raise RuntimeError(
            "dnspython 2.5.0 or newer is required for isctest.query.tls()"
        ) from e
def create(
    qname,
    qtype,
    qclass=dns.rdataclass.IN,
    dnssec: bool = True,
    rd: bool = True,
    cd: bool = False,
    ad: bool = True,
) -> dns.message.Message:
    """
    Build a DNS query message with defaults suitable for our tests.

    EDNS is always enabled; 'dnssec' is passed to
    dns.message.make_query() as 'want_dnssec'.

    The header flags chosen by make_query() are discarded and replaced
    with exactly the requested set:
      'rd' - set the RD (recursion desired) flag (default True);
      'ad' - set the AD flag (default True);
      'cd' - set the CD flag (default False).
    """
    header_flags = dns.flags.RD if rd else 0
    if ad:
        header_flags |= dns.flags.AD
    if cd:
        header_flags |= dns.flags.CD

    query = dns.message.make_query(
        qname, qtype, qclass, use_edns=True, want_dnssec=dnssec
    )
    query.flags = header_flags
    return query
def wait_for_serial(server_ip, zone, expected_serial, timeout=30):
    """Block until 'server_ip' serves SOA serial 'expected_serial' for 'zone'.

    The zone's SOA is queried over TCP (DNSSEC not requested), repeatedly
    via isctest.run.retry_with_timeout(), until the answer section holds
    exactly one SOA record with the expected serial or 'timeout' runs out.

    'server_ip' is the IP address to query (string).
    'zone' is the zone name (string, with or without trailing dot).
    'expected_serial' is the expected SOA serial number (int).
    'timeout' is the maximum time to wait in seconds (default 30).
    """
    soa_query = create(zone, "SOA", dnssec=False)

    def check():
        response = tcp(soa_query, server_ip)
        soa_rrset = response.get_rrset(
            response.answer,
            dns.name.from_text(zone),
            dns.rdataclass.IN,
            dns.rdatatype.SOA,
        )
        if soa_rrset is None:
            return False
        if len(soa_rrset) != 1:
            return False
        return soa_rrset[0].serial == expected_serial

    isctest.run.retry_with_timeout(
        check,
        timeout=timeout,
        msg=f"timed out waiting for serial {expected_serial} at {server_ip} for {zone}",
    )