Replace Optional[T] with T | None

Generated with: ruff check --extend-select UP045 --fix && black .

(cherry picked from commit fe38515ad0)
This commit is contained in:
Štěpán Balážik 2026-02-09 15:46:40 +01:00
parent 89ce3b5e74
commit 17cf986396
16 changed files with 86 additions and 95 deletions

View file

@ -66,11 +66,11 @@ class SetSpoofingModeCommand(ControlCommand):
control_subdomain = "set-spoofing-mode"
def __init__(self) -> None:
self._current_handler: Optional[ResponseSpoofer] = None
self._current_handler: ResponseSpoofer | None = None
def handle(
self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
) -> Optional[str]:
) -> str | None:
if len(args) != 1:
qctx.response.set_rcode(dns.rcode.SERVFAIL)
return "invalid control command"

View file

@ -13,7 +13,7 @@ information regarding copyright ownership.
from dataclasses import dataclass
from enum import Enum
from typing import AsyncGenerator, Optional
from typing import AsyncGenerator
import abc
import logging
@ -297,11 +297,11 @@ class ChainSetupCommand(ControlCommand):
control_subdomain = "setup-chain"
def __init__(self) -> None:
self._current_handler: Optional[ChainResponseHandler] = None
self._current_handler: ChainResponseHandler | None = None
def handle(
self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
) -> Optional[str]:
) -> str | None:
try:
actions, selectors = self._parse_args(args)
except ValueError as exc:

View file

@ -9,7 +9,7 @@
# See the COPYRIGHT file distributed with this work for additional
# information regarding copyright ownership.
from typing import AsyncGenerator, Optional
from typing import AsyncGenerator
import logging
@ -69,7 +69,7 @@ class ResponseSequenceCommand(ControlCommand):
control_subdomain = "response-sequence"
def __init__(self) -> None:
self._current_handler: Optional[ResponseHandler] = None
self._current_handler: ResponseHandler | None = None
def handle(
self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext

View file

@ -12,7 +12,7 @@ information regarding copyright ownership.
"""
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, Callable, Coroutine, Optional, Sequence, cast
from typing import Any, AsyncGenerator, Callable, Coroutine, Sequence, cast
import abc
import asyncio
@ -62,7 +62,7 @@ class _AsyncUdpHandler(asyncio.DatagramProtocol):
self,
handler: _UdpHandler,
) -> None:
self._transport: Optional[asyncio.DatagramTransport] = None
self._transport: asyncio.DatagramTransport | None = None
self._handler: _UdpHandler = handler
def connection_made(self, transport: asyncio.BaseTransport) -> None:
@ -96,9 +96,9 @@ class AsyncServer:
def __init__(
self,
udp_handler: Optional[_UdpHandler],
tcp_handler: Optional[_TcpHandler],
pidfile: Optional[str] = None,
udp_handler: _UdpHandler | None,
tcp_handler: _TcpHandler | None,
pidfile: str | None = None,
) -> None:
logging.basicConfig(
format="%(asctime)s %(levelname)8s %(message)s",
@ -122,10 +122,10 @@ class AsyncServer:
self._ip_addresses: tuple[str, str] = (ipv4_address, ipv6_address)
self._port: int = port
self._udp_handler: Optional[_UdpHandler] = udp_handler
self._tcp_handler: Optional[_TcpHandler] = tcp_handler
self._pidfile: Optional[str] = pidfile
self._work_done: Optional[asyncio.Future] = None
self._udp_handler: _UdpHandler | None = udp_handler
self._tcp_handler: _TcpHandler | None = tcp_handler
self._pidfile: str | None = pidfile
self._work_done: asyncio.Future | None = None
def _get_ipv4_address_from_directory_name(self) -> str:
containing_directory = pathlib.Path().absolute().stem
@ -256,15 +256,13 @@ class QueryContext:
socket: Peer
peer: Peer
protocol: DnsProtocol
zone: Optional[dns.zone.Zone] = field(default=None, init=False)
soa: Optional[dns.rrset.RRset] = field(default=None, init=False)
node: Optional[dns.node.Node] = field(default=None, init=False)
answer: Optional[dns.rdataset.Rdataset] = field(default=None, init=False)
alias: Optional[dns.name.Name] = field(default=None, init=False)
_initialized_response: Optional[dns.message.Message] = field(
default=None, init=False
)
_initialized_response_with_zone_data: Optional[dns.message.Message] = field(
zone: dns.zone.Zone | None = field(default=None, init=False)
soa: dns.rrset.RRset | None = field(default=None, init=False)
node: dns.node.Node | None = field(default=None, init=False)
answer: dns.rdataset.Rdataset | None = field(default=None, init=False)
alias: dns.name.Name | None = field(default=None, init=False)
_initialized_response: dns.message.Message | None = field(default=None, init=False)
_initialized_response_with_zone_data: dns.message.Message | None = field(
default=None, init=False
)
@ -309,7 +307,7 @@ class ResponseAction(abc.ABC):
"""
@abc.abstractmethod
async def perform(self) -> Optional[dns.message.Message | bytes]:
async def perform(self) -> dns.message.Message | bytes | None:
"""
This method is expected to carry out arbitrary actions (e.g. wait for a
specific amount of time, modify the answer, etc.) and then return the
@ -332,11 +330,11 @@ class DnsResponseSend(ResponseAction):
"""
response: dns.message.Message
authoritative: Optional[bool] = None
authoritative: bool | None = None
delay: float = 0.0
acknowledge_hand_rolled_response: bool = False
async def perform(self) -> Optional[dns.message.Message | bytes]:
async def perform(self) -> dns.message.Message | bytes | None:
"""
Yield a potentially delayed response that is a dns.message.Message.
"""
@ -382,7 +380,7 @@ class BytesResponseSend(ResponseAction):
response: bytes
delay: float = 0.0
async def perform(self) -> Optional[dns.message.Message | bytes]:
async def perform(self) -> dns.message.Message | bytes | None:
"""
Yield a potentially delayed response that is a sequence of bytes.
"""
@ -399,7 +397,7 @@ class ResponseDrop(ResponseAction):
Action which does nothing - as if a packet was dropped.
"""
async def perform(self) -> Optional[dns.message.Message | bytes]:
async def perform(self) -> dns.message.Message | bytes | None:
return None
@ -417,7 +415,7 @@ class CloseConnection(ResponseAction):
delay: float = 0.0
async def perform(self) -> Optional[dns.message.Message | bytes]:
async def perform(self) -> dns.message.Message | bytes | None:
if self.delay > 0:
logging.info("Waiting %.1fs before closing TCP connection", self.delay)
await asyncio.sleep(self.delay)
@ -674,7 +672,7 @@ class StaticResponseHandler(ResponseHandler):
"""
@property
def rcode(self) -> Optional[dns.rcode.Rcode]:
def rcode(self) -> dns.rcode.Rcode | None:
"""
Optional RCODE to be set in the response.
"""
@ -702,7 +700,7 @@ class StaticResponseHandler(ResponseHandler):
return []
@property
def authoritative(self) -> Optional[bool]:
def authoritative(self) -> bool | None:
"""
Whether to set the AA bit in the response.
"""
@ -752,7 +750,7 @@ class DomainHandler(ResponseHandler):
self._domains: list[dns.name.Name] = sorted(
[dns.name.from_text(d) for d in self.domains], reverse=True
)
self._matched_domain: Optional[dns.name.Name] = None
self._matched_domain: dns.name.Name | None = None
@property
def matched_domain(self) -> dns.name.Name:
@ -883,7 +881,7 @@ class _ZoneTreeNode:
A node representing a zone with one origin.
"""
zone: Optional[dns.zone.Zone]
zone: dns.zone.Zone | None
children: list["_ZoneTreeNode"] = field(default_factory=list)
@ -934,7 +932,7 @@ class _ZoneTree:
node_from.children.remove(child)
node_to.children.append(child)
def find_best_zone(self, name: dns.name.Name) -> Optional[dns.zone.Zone]:
def find_best_zone(self, name: dns.name.Name) -> dns.zone.Zone | None:
"""
Return the closest matching zone (if any) for the domain name.
"""
@ -952,7 +950,7 @@ class _DnsMessageWithTsigDisabled(dns.message.Message):
"""
class _DisableTsigHandling(contextlib.ContextDecorator):
def __init__(self, message: Optional[dns.message.Message] = None) -> None:
def __init__(self, message: dns.message.Message | None = None) -> None:
self.original_tsig_sign = dns.tsig.sign
self.original_tsig_validate = dns.tsig.validate
if message:
@ -1049,7 +1047,7 @@ class AsyncDnsServer(AsyncServer):
super().__init__(self._handle_udp, self._handle_tcp, "ans.pid")
self._zone_tree: _ZoneTree = _ZoneTree()
self._connection_handler: Optional[ConnectionHandler] = None
self._connection_handler: ConnectionHandler | None = None
self._response_handlers: list[ResponseHandler] = []
self._default_rcode = default_rcode
self._default_aa = default_aa
@ -1202,7 +1200,7 @@ class AsyncDnsServer(AsyncServer):
async def _read_tcp_query(
self, reader: asyncio.StreamReader, peer: Peer
) -> Optional[bytes]:
) -> bytes | None:
wire_length = await self._read_tcp_query_wire_length(reader, peer)
if not wire_length:
return None
@ -1211,7 +1209,7 @@ class AsyncDnsServer(AsyncServer):
async def _read_tcp_query_wire_length(
self, reader: asyncio.StreamReader, peer: Peer
) -> Optional[int]:
) -> int | None:
logging.debug("Receiving TCP message length from %s...", peer)
wire_length_bytes = await self._read_tcp_octets(reader, peer, 2)
@ -1224,7 +1222,7 @@ class AsyncDnsServer(AsyncServer):
async def _read_tcp_query_wire(
self, reader: asyncio.StreamReader, peer: Peer, wire_length: int
) -> Optional[bytes]:
) -> bytes | None:
logging.debug("Receiving TCP message (%d octets) from %s...", wire_length, peer)
wire = await self._read_tcp_octets(reader, peer, wire_length)
@ -1237,7 +1235,7 @@ class AsyncDnsServer(AsyncServer):
async def _read_tcp_octets(
self, reader: asyncio.StreamReader, peer: Peer, expected: int
) -> Optional[bytes]:
) -> bytes | None:
buffer = b""
while len(buffer) < expected:
@ -1286,7 +1284,7 @@ class AsyncDnsServer(AsyncServer):
)
def _log_response(
self, qctx: QueryContext, response: Optional[dns.message.Message | bytes]
self, qctx: QueryContext, response: dns.message.Message | bytes | None
) -> None:
if not response:
logging.info(
@ -1386,7 +1384,7 @@ class AsyncDnsServer(AsyncServer):
async def _prepare_responses(
self, qctx: QueryContext
) -> AsyncGenerator[Optional[dns.message.Message | bytes], None]:
) -> AsyncGenerator[dns.message.Message | bytes | None, None]:
"""
Yield response(s) either from response handlers or zone data.
"""
@ -1600,7 +1598,7 @@ class ControllableAsyncDnsServer(AsyncDnsServer):
async def _prepare_responses(
self, qctx: QueryContext
) -> AsyncGenerator[Optional[dns.message.Message | bytes], None]:
) -> AsyncGenerator[dns.message.Message | bytes | None, None]:
"""
Detect and handle control queries, falling back to normal processing
for non-control queries.
@ -1613,9 +1611,7 @@ class ControllableAsyncDnsServer(AsyncDnsServer):
async for response in super()._prepare_responses(qctx):
yield response
def _handle_control_command(
self, qctx: QueryContext
) -> Optional[dns.message.Message]:
def _handle_control_command(self, qctx: QueryContext) -> dns.message.Message | None:
"""
Detect and handle control queries.
@ -1691,7 +1687,7 @@ class ControlCommand(abc.ABC):
@abc.abstractmethod
def handle(
self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
) -> Optional[str]:
) -> str | None:
"""
This method is expected to carry out arbitrary actions in response to a
control query. Note that it is invoked synchronously (it is not a
@ -1729,11 +1725,11 @@ class ToggleResponsesCommand(ControlCommand):
control_subdomain = "send-responses"
def __init__(self) -> None:
self._current_handler: Optional[IgnoreAllQueries] = None
self._current_handler: IgnoreAllQueries | None = None
def handle(
self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
) -> Optional[str]:
) -> str | None:
if len(args) != 1:
logging.error("Invalid %s query %s", self, qctx.qname)
qctx.response.set_rcode(dns.rcode.SERVFAIL)
@ -1777,7 +1773,7 @@ class SwitchControlCommand(ControlCommand):
def handle(
self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
) -> Optional[str]:
) -> str | None:
if len(args) != 1 or args[0] not in self._handler_mapping:
logging.error("Invalid %s query %s", self, qctx.qname)
qctx.response.set_rcode(dns.rcode.SERVFAIL)

View file

@ -9,7 +9,7 @@
# See the COPYRIGHT file distributed with this work for additional
# information regarding copyright ownership.
from typing import Optional, cast
from typing import cast
import difflib
import shutil
@ -89,9 +89,7 @@ def noede(message: dns.message.Message) -> None:
assert not ede_options, f"unexpected EDE options {ede_options} in {message}"
def ede(
message: dns.message.Message, code: EDECode, text: Optional[str] = None
) -> None:
def ede(message: dns.message.Message, code: EDECode, text: str | None = None) -> None:
"""Check if message contains expected EDE code (and its text)."""
msg_opts = _extract_ede_options(message)
matching_opts = [opt for opt in msg_opts if opt.code == code]
@ -137,7 +135,7 @@ def same_answer(res1: dns.message.Message, res2: dns.message.Message):
def rrsets_equal(
first_rrset: dns.rrset.RRset,
second_rrset: dns.rrset.RRset,
compare_ttl: Optional[bool] = False,
compare_ttl: bool | None = False,
) -> None:
"""Compare two RRset (optionally including TTL)"""
@ -166,7 +164,7 @@ def rrsets_equal(
def zones_equal(
first_zone: dns.zone.Zone,
second_zone: dns.zone.Zone,
compare_ttl: Optional[bool] = False,
compare_ttl: bool | None = False,
) -> None:
"""Compare two zones (optionally including TTL)"""

View file

@ -12,7 +12,7 @@
# information regarding copyright ownership.
from pathlib import Path
from typing import NamedTuple, Optional
from typing import NamedTuple
import os
import re
@ -53,8 +53,8 @@ class NamedInstance:
def __init__(
self,
identifier: str,
num: Optional[int] = None,
ports: Optional[NamedPorts] = None,
num: int | None = None,
ports: NamedPorts | None = None,
) -> None:
"""
`identifier` is the name of the instance's directory
@ -94,7 +94,7 @@ class NamedInstance:
return f"10.53.0.{self.num}"
@staticmethod
def _identifier_to_num(identifier: str, num: Optional[int] = None) -> int:
def _identifier_to_num(identifier: str, num: int | None = None) -> int:
regex_match = re.match(r"^ns(?P<index>[0-9]{1,2})$", identifier)
if not regex_match:
if num is None:
@ -175,7 +175,7 @@ class NamedInstance:
watcher.wait_for_line("all zones loaded")
return cmd
def stop(self, args: Optional[list[str]] = None) -> None:
def stop(self, args: list[str] | None = None) -> None:
"""Stop the instance."""
args = args or []
perl(
@ -183,7 +183,7 @@ class NamedInstance:
[self.system_test_name, self.identifier] + args,
)
def start(self, args: Optional[list[str]] = None) -> None:
def start(self, args: list[str] | None = None) -> None:
"""Start the instance."""
args = args or []
perl(

View file

@ -13,7 +13,6 @@ from datetime import datetime, timedelta, timezone
from functools import total_ordering
from pathlib import Path
from re import compile as Re
from typing import Optional
import glob
import os
@ -324,7 +323,7 @@ class Key:
operations for KASP tests.
"""
def __init__(self, name: str, keydir: Optional[str | Path] = None):
def __init__(self, name: str, keydir: str | Path | None = None):
self.name = name
if keydir is None:
self.keydir = Path()
@ -339,7 +338,7 @@ class Key:
def get_timing(
self, metadata: str, must_exist: bool = True
) -> Optional[KeyTimingMetadata]:
) -> KeyTimingMetadata | None:
regex = rf";\s+{metadata}:\s+(\d+).*"
with open(self.keyfile, "r", encoding="utf-8") as file:
for line in file:
@ -1503,7 +1502,7 @@ def next_key_event_equals(server, zone, next_event):
def keydir_to_keylist(
zone: Optional[str], keydir: Optional[str] = None, in_use: bool = False
zone: str | None, keydir: str | None = None, in_use: bool = False
) -> list[Key]:
"""
Retrieve all keys from the key files in a directory. If 'zone' is None,
@ -1544,7 +1543,7 @@ def keydir_to_keylist(
return [k for k in all_keys if used(k)]
def keystr_to_keylist(keystr: str, keydir: Optional[str] = None) -> list[Key]:
def keystr_to_keylist(keystr: str, keydir: str | None = None) -> list[Key]:
return [Key(name, keydir) for name in keystr.split()]

View file

@ -9,8 +9,7 @@
# See the COPYRIGHT file distributed with this work for additional
# information regarding copyright ownership.
from typing import Any, Match, Optional, Pattern, TextIO, TypeAlias, TypeVar
from typing import Any, Match, Pattern, TextIO, TypeAlias, TypeVar
import abc
import os
@ -63,8 +62,8 @@ class WatchLog(abc.ABC):
...
isctest.log.watchlog.WatchLogException: timeout must be greater than 0
"""
self._fd: Optional[TextIO] = None
self._reader: Optional[LineReader] = None
self._fd: TextIO | None = None
self._reader: LineReader | None = None
self._path = path
self._wait_function_called = False
if timeout <= 0.0:

View file

@ -9,7 +9,7 @@
# See the COPYRIGHT file distributed with this work for additional
# information regarding copyright ownership.
from typing import Any, Callable, Optional
from typing import Any, Callable
import os
import time
@ -26,11 +26,11 @@ def generic_query(
query_func: Callable[..., Any],
message: dns.message.Message,
ip: str,
port: Optional[int] = None,
source: Optional[str] = None,
port: int | None = None,
source: str | None = None,
timeout: int = QUERY_TIMEOUT,
attempts: int = 10,
expected_rcode: Optional[dns.rcode.Rcode] = None,
expected_rcode: dns.rcode.Rcode | None = None,
verify: bool = False,
log_query: bool = True,
log_response: bool = True,

View file

@ -10,7 +10,6 @@
# information regarding copyright ownership.
from pathlib import Path
from typing import Optional
import os
import subprocess
@ -40,9 +39,9 @@ def cmd(
stderr=subprocess.PIPE,
log_stdout=True,
log_stderr=True,
input_text: Optional[bytes] = None,
input_text: bytes | None = None,
raise_on_exception=True,
env: Optional[dict] = None,
env: dict | None = None,
) -> CmdResult:
"""Execute a command with given args as subprocess."""
isctest.log.debug(f"isctest.run.cmd(): {' '.join(args)}")
@ -98,7 +97,7 @@ class EnvCmd:
def _run_script(
interpreter: str,
script: str,
args: Optional[list[str]] = None,
args: list[str] | None = None,
):
if args is None:
args = []
@ -130,12 +129,12 @@ def _run_script(
isctest.log.debug(" exited with %d", returncode)
def shell(script: str, args: Optional[list[str]] = None) -> None:
def shell(script: str, args: list[str] | None = None) -> None:
"""Run a given script with system's shell interpreter."""
_run_script(os.environ["SHELL"], script, args)
def perl(script: str, args: Optional[list[str]] = None) -> None:
def perl(script: str, args: list[str] | None = None) -> None:
"""Run a given script with system's perl interpreter."""
_run_script(os.environ["PERL"], script, args)

View file

@ -13,7 +13,7 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional
from typing import Any
import jinja2
@ -44,8 +44,8 @@ class TemplateEngine:
def render(
self,
output: str,
data: Optional[dict[str, Any]] = None,
template: Optional[str] = None,
data: dict[str, Any] | None = None,
template: str | None = None,
) -> None:
"""
Render `output` file from jinja `template` and fill in the `data`. The
@ -69,7 +69,7 @@ class TemplateEngine:
stream = self.j2env.get_template(template).stream(data)
stream.dump(output, encoding="utf-8")
def render_auto(self, data: Optional[dict[str, Any]] = None):
def render_auto(self, data: dict[str, Any] | None = None):
"""
Render all *.j2 templates with default (and optionally the provided)
values and write the output to files without the .j2 extensions.

View file

@ -12,7 +12,7 @@
# information regarding copyright ownership.
from re import compile as Re
from typing import Iterator, Match, Optional, Pattern, TextIO
from typing import Iterator, Match, Pattern, TextIO
import abc
import re
@ -150,7 +150,7 @@ class LineReader(Grep):
self._stream = stream
self._linebuf = ""
def readline(self) -> Optional[str]:
def readline(self) -> str | None:
"""
Wrapper around io.readline() function to handle unfinished lines.

View file

@ -9,7 +9,7 @@
# See the COPYRIGHT file distributed with this work for additional
# information regarding copyright ownership.
from typing import NamedTuple, Optional
from typing import NamedTuple
import os
import platform
@ -273,7 +273,7 @@ def _algorithms_env(algs: AlgorithmSet, name: str) -> dict[str, str]:
return algs_env
def set_algorithm_set(name: Optional[str]):
def set_algorithm_set(name: str | None):
if name is None:
name = "stable"
assert name in ALGORITHM_SETS, f'ALGORITHM_SET "{name}" unknown'

View file

@ -10,7 +10,6 @@
# information regarding copyright ownership.
from re import compile as Re
from typing import Optional
import os
@ -24,7 +23,7 @@ OPENSSL_VARS = {
}
def parse_openssl_config(path: Optional[str]):
def parse_openssl_config(path: str | None):
if path is None or not os.path.exists(path):
OPENSSL_VARS["ENGINE_ARG"] = None
OPENSSL_VARS["SOFTHSM2_MODULE"] = None

View file

@ -17,7 +17,7 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Container, Iterable, Optional
from typing import Container, Iterable
import os
@ -291,7 +291,7 @@ class NSEC3Params:
algorithm: int
flags: int
iterations: int
salt: Optional[bytes]
salt: bytes | None
class NSEC3Checker:

View file

@ -99,6 +99,7 @@ lint.select = [
# unnecessary `typing` imports
"UP006",
"UP007",
"UP045",
# f-strings
"UP031",
"UP032",