mirror of
https://github.com/isc-projects/bind9.git
synced 2026-05-27 12:13:20 -04:00
[9.20] chg: test: Use isctest.asyncserver in the "pipelined" test
Replace the custom DNS server used in the "pipelined" system test with new code based on the isctest.asyncserver module. Backport of MR !11516 Merge branch 'backport-michal/pipelined-asyncserver-9.20' into 'bind-9.20' See merge request isc-projects/bind9!11552
This commit is contained in:
commit
abc11cf63c
2 changed files with 148 additions and 224 deletions
|
|
@ -266,6 +266,7 @@ class QueryContext:
|
|||
|
||||
query: dns.message.Message
|
||||
response: dns.message.Message
|
||||
socket: Peer
|
||||
peer: Peer
|
||||
protocol: DnsProtocol
|
||||
zone: Optional[dns.zone.Zone] = field(default=None, init=False)
|
||||
|
|
@ -787,6 +788,108 @@ class DomainHandler(ResponseHandler):
|
|||
return False
|
||||
|
||||
|
||||
class ForwarderHandler(ResponseHandler):
|
||||
"""
|
||||
A handler forwarding all received queries to another DNS server with an
|
||||
optional delay and then relaying the responses back to the original client.
|
||||
|
||||
Queries are currently always forwarded via UDP.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def target(self) -> str:
|
||||
"""
|
||||
The address of the DNS server to forward queries to.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def port(self) -> int:
|
||||
"""
|
||||
The port of the DNS server to forward queries to.
|
||||
|
||||
The default value of 0 causes the same port as the one used by this
|
||||
server for listening to be used.
|
||||
"""
|
||||
return 0
|
||||
|
||||
@property
|
||||
def delay(self) -> float:
|
||||
"""
|
||||
The number of seconds to wait before forwarding each query.
|
||||
"""
|
||||
return 0.0
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}(target: {self.target}:{self.port})"
|
||||
|
||||
class ForwarderProtocol(asyncio.DatagramProtocol):
|
||||
def __init__(self, query: bytes, response: asyncio.Future) -> None:
|
||||
self._query = query
|
||||
self._response = response
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
logging.debug("[OUT] %s", self._query.hex())
|
||||
cast(asyncio.DatagramTransport, transport).sendto(self._query)
|
||||
|
||||
def datagram_received(self, data: bytes, _: Tuple[str, int]) -> None:
|
||||
logging.debug("[IN] %s", data.hex())
|
||||
self._response.set_result(data)
|
||||
|
||||
async def get_responses(
|
||||
self, qctx: QueryContext
|
||||
) -> AsyncGenerator[ResponseAction, None]:
|
||||
loop = asyncio.get_running_loop()
|
||||
response = loop.create_future()
|
||||
forwarding_target = f"{self.target}:{self.port or qctx.socket.port}"
|
||||
|
||||
if self.delay > 0:
|
||||
logging.info(
|
||||
"Waiting %.1fs before forwarding %s query from %s to %s over UDP",
|
||||
self.delay,
|
||||
qctx.protocol.name,
|
||||
qctx.peer,
|
||||
forwarding_target,
|
||||
)
|
||||
await asyncio.sleep(self.delay)
|
||||
|
||||
logging.info(
|
||||
"Forwarding %s query from %s to %s over UDP",
|
||||
qctx.protocol.name,
|
||||
qctx.peer,
|
||||
forwarding_target,
|
||||
)
|
||||
|
||||
transport, _ = await loop.create_datagram_endpoint(
|
||||
lambda: self.ForwarderProtocol(qctx.query.to_wire(), response),
|
||||
local_addr=(qctx.socket.host, 0),
|
||||
remote_addr=(self.target, self.port or qctx.socket.port),
|
||||
)
|
||||
|
||||
try:
|
||||
await response
|
||||
finally:
|
||||
transport.close()
|
||||
|
||||
logging.info(
|
||||
"Relaying UDP response from %s to %s over %s",
|
||||
forwarding_target,
|
||||
qctx.peer,
|
||||
qctx.protocol.name,
|
||||
)
|
||||
|
||||
try:
|
||||
message = _DnsMessageWithTsigDisabled.from_wire(response.result())
|
||||
yield DnsResponseSend(message, acknowledge_hand_rolled_response=True)
|
||||
except dns.exception.DNSException:
|
||||
logging.warning(
|
||||
"Failed to parse response from %s as a DNS message, relaying it as raw bytes",
|
||||
forwarding_target,
|
||||
)
|
||||
yield BytesResponseSend(response.result())
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ZoneTreeNode:
|
||||
"""
|
||||
|
|
@ -1072,8 +1175,10 @@ class AsyncDnsServer(AsyncServer):
|
|||
self, wire: bytes, addr: Tuple[str, int], transport: asyncio.DatagramTransport
|
||||
) -> None:
|
||||
logging.debug("Received UDP message: %s", wire.hex())
|
||||
socket_info = transport.get_extra_info("sockname")
|
||||
socket = Peer(socket_info[0], socket_info[1])
|
||||
peer = Peer(addr[0], addr[1])
|
||||
responses = self._handle_query(wire, peer, DnsProtocol.UDP)
|
||||
responses = self._handle_query(wire, socket, peer, DnsProtocol.UDP)
|
||||
async for response in responses:
|
||||
logging.debug("Sending UDP message: %s", response.hex())
|
||||
transport.sendto(response, addr)
|
||||
|
|
@ -1170,39 +1275,39 @@ class AsyncDnsServer(AsyncServer):
|
|||
async def _send_tcp_response(
|
||||
self, writer: asyncio.StreamWriter, peer: Peer, wire: bytes
|
||||
) -> None:
|
||||
responses = self._handle_query(wire, peer, DnsProtocol.TCP)
|
||||
socket_info = writer.get_extra_info("sockname")
|
||||
socket = Peer(socket_info[0], socket_info[1])
|
||||
responses = self._handle_query(wire, socket, peer, DnsProtocol.TCP)
|
||||
async for response in responses:
|
||||
logging.debug("Sending TCP response: %s", response.hex())
|
||||
writer.write(response)
|
||||
await writer.drain()
|
||||
|
||||
def _log_query(self, qctx: QueryContext, peer: Peer, protocol: DnsProtocol) -> None:
|
||||
def _log_query(self, qctx: QueryContext) -> None:
|
||||
logging.info(
|
||||
"Received %s/%s/%s (ID=%d) query from %s (%s)",
|
||||
"Received %s/%s/%s (ID=%d) query from %s on %s (%s)",
|
||||
qctx.qname.to_text(omit_final_dot=True),
|
||||
dns.rdataclass.to_text(qctx.qclass),
|
||||
dns.rdatatype.to_text(qctx.qtype),
|
||||
qctx.query.id,
|
||||
peer,
|
||||
protocol.name,
|
||||
qctx.peer,
|
||||
qctx.socket,
|
||||
qctx.protocol.name,
|
||||
)
|
||||
logging.debug(
|
||||
"\n".join([f"[IN] {l}" for l in [""] + str(qctx.query).splitlines()])
|
||||
)
|
||||
|
||||
def _log_response(
|
||||
self,
|
||||
qctx: QueryContext,
|
||||
response: Optional[Union[dns.message.Message, bytes]],
|
||||
peer: Peer,
|
||||
protocol: DnsProtocol,
|
||||
self, qctx: QueryContext, response: Optional[Union[dns.message.Message, bytes]]
|
||||
) -> None:
|
||||
if not response:
|
||||
logging.info(
|
||||
"Not sending a response to query (ID=%d) from %s (%s)",
|
||||
"Not sending a response to query (ID=%d) from %s on %s (%s)",
|
||||
qctx.query.id,
|
||||
peer,
|
||||
protocol.name,
|
||||
qctx.peer,
|
||||
qctx.socket,
|
||||
qctx.protocol.name,
|
||||
)
|
||||
return
|
||||
|
||||
|
|
@ -1217,7 +1322,7 @@ class AsyncDnsServer(AsyncServer):
|
|||
qtype = "-"
|
||||
|
||||
logging.info(
|
||||
"Sending %s/%s/%s (ID=%d) response (%d/%d/%d/%d) to a query (ID=%d) from %s (%s)",
|
||||
"Sending %s/%s/%s (ID=%d) response (%d/%d/%d/%d) to a query (ID=%d) from %s on %s (%s)",
|
||||
qname,
|
||||
qclass,
|
||||
qtype,
|
||||
|
|
@ -1227,8 +1332,9 @@ class AsyncDnsServer(AsyncServer):
|
|||
len(response.authority),
|
||||
len(response.additional),
|
||||
qctx.query.id,
|
||||
peer,
|
||||
protocol.name,
|
||||
qctx.peer,
|
||||
qctx.socket,
|
||||
qctx.protocol.name,
|
||||
)
|
||||
logging.debug(
|
||||
"\n".join([f"[OUT] {l}" for l in [""] + str(response).splitlines()])
|
||||
|
|
@ -1236,16 +1342,17 @@ class AsyncDnsServer(AsyncServer):
|
|||
return
|
||||
|
||||
logging.info(
|
||||
"Sending response (%d bytes) to a query (ID=%d) from %s (%s)",
|
||||
"Sending response (%d bytes) to a query (ID=%d) from %s on %s (%s)",
|
||||
len(response),
|
||||
qctx.query.id,
|
||||
peer,
|
||||
protocol.name,
|
||||
qctx.peer,
|
||||
qctx.socket,
|
||||
qctx.protocol.name,
|
||||
)
|
||||
logging.debug("[OUT] %s", response.hex())
|
||||
|
||||
async def _handle_query(
|
||||
self, wire: bytes, peer: Peer, protocol: DnsProtocol
|
||||
self, wire: bytes, socket: Peer, peer: Peer, protocol: DnsProtocol
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
"""
|
||||
Yield wire data to send as a response over the established transport.
|
||||
|
|
@ -1256,11 +1363,11 @@ class AsyncDnsServer(AsyncServer):
|
|||
logging.error("Invalid query from %s (%s): %s", peer, wire.hex(), exc)
|
||||
return
|
||||
response_stub = _make_asyncserver_response(query)
|
||||
qctx = QueryContext(query, response_stub, peer, protocol)
|
||||
self._log_query(qctx, peer, protocol)
|
||||
qctx = QueryContext(query, response_stub, socket, peer, protocol)
|
||||
self._log_query(qctx)
|
||||
responses = self._prepare_responses(qctx)
|
||||
async for response in responses:
|
||||
self._log_response(qctx, response, peer, protocol)
|
||||
self._log_response(qctx, response)
|
||||
if response:
|
||||
if isinstance(response, dns.message.Message):
|
||||
response = response.to_wire(max_size=65535)
|
||||
|
|
|
|||
|
|
@ -1,211 +1,28 @@
|
|||
# Copyright (C) Internet Systems Consortium, Inc. ("ISC")
|
||||
#
|
||||
# SPDX-License-Identifier: MPL-2.0
|
||||
#
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
# file, you can obtain one at https://mozilla.org/MPL/2.0/.
|
||||
#
|
||||
# See the COPYRIGHT file distributed with this work for additional
|
||||
# information regarding copyright ownership.
|
||||
"""
|
||||
Copyright (C) Internet Systems Consortium, Inc. ("ISC")
|
||||
|
||||
############################################################################
|
||||
#
|
||||
# This tool acts as a TCP/UDP proxy and delays all incoming packets by 500
|
||||
# milliseconds.
|
||||
#
|
||||
# We use it to check pipelining - a client sents 8 questions over a
|
||||
# pipelined connection - that require asking a normal (examplea) and a
|
||||
# slow-responding (exampleb) servers:
|
||||
# a.examplea
|
||||
# a.exampleb
|
||||
# b.examplea
|
||||
# b.exampleb
|
||||
# c.examplea
|
||||
# c.exampleb
|
||||
# d.examplea
|
||||
# d.exampleb
|
||||
#
|
||||
# If pipelining works properly the answers will be returned out of order
|
||||
# with all answers from examplea returned first, and then all answers
|
||||
# from exampleb.
|
||||
#
|
||||
############################################################################
|
||||
SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
from __future__ import print_function
|
||||
This Source Code Form is subject to the terms of the Mozilla Public
|
||||
License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
file, you can obtain one at https://mozilla.org/MPL/2.0/.
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import select
|
||||
import signal
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
import struct
|
||||
See the COPYRIGHT file distributed with this work for additional
|
||||
information regarding copyright ownership.
|
||||
"""
|
||||
|
||||
DELAY = 0.5
|
||||
THREADS = []
|
||||
from isctest.asyncserver import AsyncDnsServer, ForwarderHandler
|
||||
|
||||
|
||||
def log(msg):
|
||||
print(datetime.datetime.now().strftime("%d-%b-%Y %H:%M:%S.%f ") + msg)
|
||||
class ForwardToNs2(ForwarderHandler):
|
||||
target = "10.53.0.2"
|
||||
delay = 0.5
|
||||
|
||||
|
||||
def sigterm(*_):
|
||||
log("SIGTERM received, shutting down")
|
||||
for thread in THREADS:
|
||||
thread.close()
|
||||
thread.join()
|
||||
os.remove("ans.pid")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
class TCPDelayer(threading.Thread):
|
||||
"""For a given TCP connection conn we open a connection to (ip, port),
|
||||
and then we delay each incoming packet by DELAY by putting it in a
|
||||
queue.
|
||||
In the pipelined test TCP should not be used, but it's here for
|
||||
completnes.
|
||||
"""
|
||||
|
||||
def __init__(self, conn, ip, port):
|
||||
threading.Thread.__init__(self)
|
||||
self.conn = conn
|
||||
self.cconn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.cconn.connect((ip, port))
|
||||
self.queue = []
|
||||
self.running = True
|
||||
|
||||
def close(self):
|
||||
self.running = False
|
||||
|
||||
def run(self):
|
||||
while self.running:
|
||||
curr_timeout = 0.5
|
||||
try:
|
||||
curr_timeout = self.queue[0][0] - time.monotonic()
|
||||
except StopIteration:
|
||||
pass
|
||||
if curr_timeout > 0:
|
||||
if curr_timeout == 0:
|
||||
curr_timeout = 0.5
|
||||
rfds, _, _ = select.select(
|
||||
[self.conn, self.cconn], [], [], curr_timeout
|
||||
)
|
||||
if self.conn in rfds:
|
||||
data = self.conn.recv(65535)
|
||||
if not data:
|
||||
return
|
||||
self.queue.append((time.monotonic() + DELAY, data))
|
||||
if self.cconn in rfds:
|
||||
data = self.cconn.recv(65535)
|
||||
if not data == 0:
|
||||
return
|
||||
self.conn.send(data)
|
||||
try:
|
||||
while self.queue[0][0] - time.monotonic() < 0:
|
||||
_, data = self.queue.pop(0)
|
||||
self.cconn.send(data)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
|
||||
class UDPDelayer(threading.Thread):
|
||||
"""Every incoming UDP packet is put in a queue for DELAY time, then
|
||||
it's sent to (ip, port). We remember the query id to send the
|
||||
response we get to a proper source, responses are not delayed.
|
||||
"""
|
||||
|
||||
def __init__(self, usock, ip, port):
|
||||
threading.Thread.__init__(self)
|
||||
self.sock = usock
|
||||
self.csock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.dst = (ip, port)
|
||||
self.queue = []
|
||||
self.qid_mapping = {}
|
||||
self.running = True
|
||||
|
||||
def close(self):
|
||||
self.running = False
|
||||
|
||||
def run(self):
|
||||
while self.running:
|
||||
curr_timeout = 0.5
|
||||
if self.queue:
|
||||
curr_timeout = self.queue[0][0] - time.monotonic()
|
||||
if curr_timeout >= 0:
|
||||
if curr_timeout == 0:
|
||||
curr_timeout = 0.5
|
||||
rfds, _, _ = select.select(
|
||||
[self.sock, self.csock], [], [], curr_timeout
|
||||
)
|
||||
if self.sock in rfds:
|
||||
data, addr = self.sock.recvfrom(65535)
|
||||
if not data:
|
||||
return
|
||||
self.queue.append((time.monotonic() + DELAY, data))
|
||||
qid = struct.unpack(">H", data[:2])[0]
|
||||
log("Received a query from %s, queryid %d" % (str(addr), qid))
|
||||
self.qid_mapping[qid] = addr
|
||||
if self.csock in rfds:
|
||||
data, addr = self.csock.recvfrom(65535)
|
||||
if not data:
|
||||
return
|
||||
qid = struct.unpack(">H", data[:2])[0]
|
||||
dst = self.qid_mapping.get(qid)
|
||||
if dst is not None:
|
||||
self.sock.sendto(data, dst)
|
||||
log(
|
||||
"Received a response from %s, queryid %d, sending to %s"
|
||||
% (str(addr), qid, str(dst))
|
||||
)
|
||||
while self.queue and self.queue[0][0] - time.monotonic() < 0:
|
||||
_, data = self.queue.pop(0)
|
||||
qid = struct.unpack(">H", data[:2])[0]
|
||||
log("Sending a query to %s, queryid %d" % (str(self.dst), qid))
|
||||
self.csock.sendto(data, self.dst)
|
||||
|
||||
|
||||
def main():
|
||||
signal.signal(signal.SIGTERM, sigterm)
|
||||
signal.signal(signal.SIGINT, sigterm)
|
||||
|
||||
with open("ans.pid", "w") as pidfile:
|
||||
print(os.getpid(), file=pidfile)
|
||||
|
||||
listenip = "10.53.0.5"
|
||||
serverip = "10.53.0.2"
|
||||
|
||||
try:
|
||||
port = int(os.environ["PORT"])
|
||||
except KeyError:
|
||||
port = 5300
|
||||
|
||||
log("Listening on %s:%d" % (listenip, port))
|
||||
|
||||
usock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
usock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
usock.bind((listenip, port))
|
||||
thread = UDPDelayer(usock, serverip, port)
|
||||
thread.start()
|
||||
THREADS.append(thread)
|
||||
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.bind((listenip, port))
|
||||
sock.listen(1)
|
||||
sock.settimeout(1)
|
||||
|
||||
while True:
|
||||
try:
|
||||
clientsock, _ = sock.accept()
|
||||
log("Accepted connection from %s" % clientsock)
|
||||
thread = TCPDelayer(clientsock, serverip, port)
|
||||
thread.start()
|
||||
THREADS.append(thread)
|
||||
except socket.timeout:
|
||||
pass
|
||||
def main() -> None:
|
||||
server = AsyncDnsServer()
|
||||
server.install_response_handlers(ForwardToNs2())
|
||||
server.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in a new issue