Add type annotations to the certbot package (part 2) (#9085)

* Extract from #9084

* Cast/ignore types during the transition

* Clean up

* Fix assertion

* Update certbot/certbot/display/ops.py

Co-authored-by: alexzorin <alex@zor.io>

* Use sequence

* Improve documentation of "default" in display

* Fix contract

* Fix types

* Fix type

* Fix type

* Update certbot/certbot/display/ops.py

Co-authored-by: alexzorin <alex@zor.io>

Co-authored-by: alexzorin <alex@zor.io>
This commit is contained in:
Adrien Ferrand 2021-11-24 08:33:09 +01:00 committed by GitHub
parent d1821b3ad7
commit 19147e1b8c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 293 additions and 210 deletions

View file

@ -16,9 +16,8 @@ from typing import Union
from acme import challenges
from certbot import errors
from certbot import interfaces
from certbot import util
from certbot.achallenges import KeyAuthorizationAnnotatedChallenge # pylint: disable=unused-import
from certbot.achallenges import KeyAuthorizationAnnotatedChallenge
from certbot.compat import filesystem
from certbot.compat import os
from certbot.display import util as display_util
@ -116,7 +115,7 @@ class OsOptions:
# TODO: Add directives to sites-enabled... not sites-available.
# sites-available doesn't allow immediate find_dir search even with save()
# and load()
class ApacheConfigurator(common.Installer, interfaces.Authenticator):
class ApacheConfigurator(common.Configurator):
"""Apache configurator.
:ivar config: Configuration.

View file

@ -3,6 +3,7 @@ import errno
import logging
from typing import List
from typing import Set
from typing import TYPE_CHECKING
from certbot import errors
from certbot.compat import filesystem
@ -11,6 +12,9 @@ from certbot.plugins import common
from certbot_apache._internal.obj import VirtualHost # pylint: disable=unused-import
from certbot_apache._internal.parser import get_aug_path
if TYPE_CHECKING:
from certbot_apache._internal.configurator import ApacheConfigurator # pragma: no cover
logger = logging.getLogger(__name__)
@ -46,8 +50,9 @@ class ApacheHttp01(common.ChallengePerformer):
</Location>
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, configurator: "ApacheConfigurator") -> None:
super().__init__(configurator)
self.configurator: "ApacheConfigurator"
self.challenge_conf_pre = os.path.join(
self.configurator.conf("challenge-location"),
"le_http_01_challenge_pre.conf")

View file

@ -2,6 +2,7 @@
import logging
import re
from typing import Optional
from typing import Union
from lexicon.providers import linode
from lexicon.providers import linode4
@ -58,7 +59,7 @@ class Authenticator(dns_common.DNSAuthenticator):
if not self.credentials: # pragma: no cover
raise errors.Error("Plugin has not been prepared.")
api_key = self.credentials.conf('key')
api_version = self.credentials.conf('version')
api_version: Optional[Union[str, int]] = self.credentials.conf('version')
if api_version == '':
api_version = None

View file

@ -20,7 +20,6 @@ from acme import challenges
from acme import crypto_util as acme_crypto_util
from certbot import crypto_util
from certbot import errors
from certbot import interfaces
from certbot import util
from certbot.display import util as display_util
from certbot.compat import os
@ -42,7 +41,7 @@ NO_SSL_MODIFIER = 4
logger = logging.getLogger(__name__)
class NginxConfigurator(common.Installer, interfaces.Authenticator):
class NginxConfigurator(common.Configurator):
"""Nginx configurator.
.. todo:: Add proper support for comments in the config. Currently,

View file

@ -4,6 +4,7 @@ import io
import logging
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from acme import challenges
from certbot import achallenges
@ -13,6 +14,9 @@ from certbot.plugins import common
from certbot_nginx._internal import nginxparser
from certbot_nginx._internal import obj
if TYPE_CHECKING:
from certbot_nginx._internal.configurator import NginxConfigurator
logger = logging.getLogger(__name__)
@ -36,8 +40,9 @@ class NginxHttp01(common.ChallengePerformer):
"""
def __init__(self, configurator):
def __init__(self, configurator: "NginxConfigurator") -> None:
super().__init__(configurator)
self.configurator: "NginxConfigurator"
self.challenge_conf = os.path.join(
configurator.config.config_dir, "le_http_01_cert_challenge.conf")

View file

@ -5,15 +5,18 @@ import errno
import os # pylint: disable=os-module-forbidden
import stat
import sys
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
try:
import ntsecuritycon
import win32security
import win32con
import win32api
import win32file
import pywintypes
import win32api
import win32con
import win32file
import win32security
import winerror
except ImportError:
POSIX_MODE = True
@ -28,7 +31,7 @@ else:
# that could happen with this kind of pattern.
class _WindowsUmask:
"""Store the current umask to apply on Windows"""
def __init__(self):
def __init__(self) -> None:
self.mask = 0o022
@ -531,7 +534,7 @@ def has_min_permissions(path: str, min_mode: int) -> bool:
return True
def _win_is_executable(path):
def _win_is_executable(path: str) -> bool:
if not os.path.isfile(path):
return False
@ -547,7 +550,7 @@ def _win_is_executable(path):
return mode & ntsecuritycon.FILE_GENERIC_EXECUTE == ntsecuritycon.FILE_GENERIC_EXECUTE
def _apply_win_mode(file_path, mode):
def _apply_win_mode(file_path: str, mode: int) -> None:
"""
This function converts the given POSIX mode into a Windows ACL list, and applies it to the
file given its path. If the given path is a symbolic link, it will resolved to apply the
@ -566,7 +569,7 @@ def _apply_win_mode(file_path, mode):
win32security.SetFileSecurity(file_path, win32security.DACL_SECURITY_INFORMATION, security)
def _generate_dacl(user_sid, mode, mask=None):
def _generate_dacl(user_sid: Any, mode: int, mask: Optional[int] = None) -> Any:
if mask:
mode = mode & (0o777 - mask)
analysis = _analyze_mode(mode)
@ -602,7 +605,7 @@ def _generate_dacl(user_sid, mode, mask=None):
return dacl
def _analyze_mode(mode):
def _analyze_mode(mode: int) -> Dict[str, Dict[str, int]]:
return {
'user': {
'read': mode & stat.S_IRUSR,
@ -617,7 +620,7 @@ def _analyze_mode(mode):
}
def _copy_win_ownership(src, dst):
def _copy_win_ownership(src: str, dst: str) -> None:
# Resolve symbolic links
src = realpath(src)
@ -632,7 +635,7 @@ def _copy_win_ownership(src, dst):
win32security.SetFileSecurity(dst, win32security.OWNER_SECURITY_INFORMATION, security_dst)
def _copy_win_mode(src, dst):
def _copy_win_mode(src: str, dst: str) -> None:
# Resolve symbolic links
src = realpath(src)
@ -645,7 +648,7 @@ def _copy_win_mode(src, dst):
win32security.SetFileSecurity(dst, win32security.DACL_SECURITY_INFORMATION, security_dst)
def _generate_windows_flags(rights_desc):
def _generate_windows_flags(rights_desc: Dict[str, int]) -> int:
# Some notes about how each POSIX right is interpreted.
#
# For the rights read and execute, we have a pretty bijective relation between
@ -676,7 +679,7 @@ def _generate_windows_flags(rights_desc):
return flag
def _check_win_mode(file_path, mode):
def _check_win_mode(file_path: str, mode: int) -> bool:
# Resolve symbolic links
file_path = realpath(file_path)
# Get current dacl file
@ -698,7 +701,7 @@ def _check_win_mode(file_path, mode):
return _compare_dacls(dacl, ref_dacl)
def _compare_dacls(dacl1, dacl2):
def _compare_dacls(dacl1: Any, dacl2: Any) -> bool:
"""
This method compare the two given DACLs to check if they are identical.
Identical means here that they contains the same set of ACEs in the same order.
@ -707,7 +710,7 @@ def _compare_dacls(dacl1, dacl2):
[dacl2.GetAce(index) for index in range(dacl2.GetAceCount())])
def _get_current_user():
def _get_current_user() -> Any:
"""
Return the pySID corresponding to the current user.
"""

View file

@ -8,17 +8,18 @@ import logging
import select
import subprocess
import sys
import warnings
from typing import Optional
from typing import Tuple
import warnings
from certbot import errors
from certbot.compat import os
try:
from win32com.shell import shell as shellwin32
from win32console import GetStdHandle, STD_OUTPUT_HANDLE
from pywintypes import error as pywinerror
from win32com.shell import shell as shellwin32
from win32console import GetStdHandle
from win32console import STD_OUTPUT_HANDLE
POSIX_MODE = False
except ImportError: # pragma: no cover
POSIX_MODE = True
@ -61,7 +62,7 @@ def prepare_virtual_console() -> None:
logger.debug("Failed to set console mode", exc_info=True)
def readline_with_timeout(timeout: float, prompt: str) -> str:
def readline_with_timeout(timeout: float, prompt: Optional[str]) -> str:
"""
Read user input to return the first line entered, or raise after specified timeout.
@ -79,7 +80,7 @@ def readline_with_timeout(timeout: float, prompt: str) -> str:
rlist, _, _ = select.select([sys.stdin], [], [], timeout)
if not rlist:
raise errors.Error(
"Timed out waiting for answer to prompt '{0}'".format(prompt))
"Timed out waiting for answer to prompt '{0}'".format(prompt if prompt else ""))
return rlist[0].readline()
except OSError:
# Windows specific

View file

@ -1,9 +1,18 @@
"""Contains UI methods for LE user operations."""
import logging
from textwrap import indent
from typing import Any
from typing import Callable
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from certbot import errors
from certbot import interfaces
from certbot import util
from certbot._internal import account
from certbot._internal.display import util as internal_display_util
from certbot.compat import os
from certbot.display import util as display_util
@ -11,7 +20,7 @@ from certbot.display import util as display_util
logger = logging.getLogger(__name__)
def get_email(invalid=False, optional=True):
def get_email(invalid: bool = False, optional: bool = True) -> str:
"""Prompt for valid email address.
:param bool invalid: True if an invalid address was provided by the user
@ -65,7 +74,7 @@ def get_email(invalid=False, optional=True):
invalid = bool(email)
def choose_account(accounts):
def choose_account(accounts: List[account.Account]) -> Optional[account.Account]:
"""Choose an account.
:param list accounts: Containing at least one
@ -81,22 +90,25 @@ def choose_account(accounts):
return None
def choose_values(values, question=None):
def choose_values(values: List[str], question: Optional[str] = None) -> List[str]:
"""Display screen to let user pick one or multiple values from the provided
list.
:param list values: Values to select from
:param str question: Question to ask to user while choosing values
:returns: List of selected values
:rtype: list
"""
code, items = display_util.checklist(question, tags=values, force_interactive=True)
code, items = display_util.checklist(question if question else "", tags=values,
force_interactive=True)
if code == display_util.OK and items:
return items
return []
def choose_names(installer, question=None):
def choose_names(installer: Optional[interfaces.Installer],
question: Optional[str] = None) -> List[str]:
"""Display screen to select domains to validate.
:param installer: An installer object
@ -125,7 +137,7 @@ def choose_names(installer, question=None):
return []
def get_valid_domains(domains):
def get_valid_domains(domains: Iterable[str]) -> List[str]:
"""Helper method for choose_names that implements basic checks
on domain names
@ -133,7 +145,7 @@ def get_valid_domains(domains):
:return: List of valid domains
:rtype: list
"""
valid_domains = []
valid_domains: List[str] = []
for domain in domains:
try:
valid_domains.append(util.enforce_domain_sanity(domain))
@ -142,7 +154,7 @@ def get_valid_domains(domains):
return valid_domains
def _sort_names(FQDNs):
def _sort_names(FQDNs: Iterable[str]) -> List[str]:
"""Sort FQDNs by SLD (and if many, by their subdomains)
:param list FQDNs: list of domain names
@ -153,7 +165,8 @@ def _sort_names(FQDNs):
return sorted(FQDNs, key=lambda fqdn: fqdn.split('.')[::-1][1:])
def _filter_names(names, override_question=None):
def _filter_names(names: Iterable[str],
override_question: Optional[str] = None) -> Tuple[str, List[str]]:
"""Determine which names the user would like to select from a list.
:param list names: domain names
@ -175,7 +188,7 @@ def _filter_names(names, override_question=None):
return code, [str(s) for s in names]
def _choose_names_manually(prompt_prefix=""):
def _choose_names_manually(prompt_prefix: str = "") -> List[str]:
"""Manually input names for those without an installer.
:param str prompt_prefix: string to prepend to prompt for domains
@ -229,7 +242,7 @@ def _choose_names_manually(prompt_prefix=""):
return []
def success_installation(domains):
def success_installation(domains: Sequence[str]) -> None:
"""Display a box confirming the installation of HTTPS.
:param list domains: domain names which were enabled
@ -241,7 +254,7 @@ def success_installation(domains):
)
def success_renewal(unused_domains):
def success_renewal(unused_domains: Sequence[str]) -> None:
"""Display a box confirming the renewal of an existing certificate.
:param list domains: domain names which were renewed
@ -253,7 +266,7 @@ def success_renewal(unused_domains):
)
def success_revocation(cert_path):
def success_revocation(cert_path: str) -> None:
"""Display a message confirming a certificate has been revoked.
:param list cert_path: path to certificate which was revoked.
@ -283,7 +296,7 @@ def report_executed_command(command_name: str, returncode: int, stdout: str, std
logger.warning("%s ran with error output:\n%s", command_name, indent(err_s, ' '))
def _gen_https_names(domains):
def _gen_https_names(domains: Sequence[str]) -> str:
"""Returns a string of the https domains.
Domains are formatted nicely with ``https://`` prepended to each.
@ -304,7 +317,9 @@ def _gen_https_names(domains):
return ""
def _get_validated(method, validator, message, default=None, **kwargs):
def _get_validated(method: Callable[..., Tuple[str, str]],
validator: Callable[[str], Any], message: str,
default: Optional[str] = None, **kwargs: Any) -> Tuple[str, str]:
if default is not None:
try:
validator(default)
@ -331,7 +346,8 @@ def _get_validated(method, validator, message, default=None, **kwargs):
return code, raw
def validated_input(validator, *args, **kwargs):
def validated_input(validator: Callable[[str], Any],
*args: Any, **kwargs: Any) -> Tuple[str, str]:
"""Like `~certbot.display.util.input_text`, but with validation.
:param callable validator: A method which will be called on the
@ -345,7 +361,8 @@ def validated_input(validator, *args, **kwargs):
return _get_validated(display_util.input_text, validator, *args, **kwargs)
def validated_directory(validator, *args, **kwargs):
def validated_directory(validator: Callable[[str], Any],
*args: Any, **kwargs: Any) -> Tuple[str, str]:
"""Like `~certbot.display.util.directory_select`, but with validation.
:param callable validator: A method which will be called on the

View file

@ -11,6 +11,7 @@ Other messages can use the `logging` module. See `log.py`.
"""
import sys
from types import ModuleType
from typing import Any
from typing import cast
from typing import List
from typing import Optional
@ -18,7 +19,7 @@ from typing import Tuple
from typing import Union
import warnings
from certbot._internal.display import obj
# These specific imports from certbot._internal.display.obj and
# certbot._internal.display.util are done to not break the public API of this
# module.
@ -28,8 +29,6 @@ from certbot._internal.display.obj import SIDE_FRAME # pylint: disable=unused-i
from certbot._internal.display.util import input_with_timeout # pylint: disable=unused-import
from certbot._internal.display.util import separate_list_input # pylint: disable=unused-import
from certbot._internal.display.util import summarize_domain_list # pylint: disable=unused-import
from certbot._internal.display import obj
# These constants are defined this way to make them easier to document with
# Sphinx and to not couple our public docstrings to our internal ones.
@ -77,7 +76,7 @@ def notification(message: str, pause: bool = True, wrap: bool = True,
force_interactive=force_interactive, decorate=decorate)
def menu(message: str, choices: Union[List[str], Tuple[str, str]],
def menu(message: str, choices: Union[List[str], List[Tuple[str, str]]],
default: Optional[int] = None, cli_flag: Optional[str] = None,
force_interactive: bool = False) -> Tuple[str, int]:
"""Display a menu.
@ -89,7 +88,7 @@ def menu(message: str, choices: Union[List[str], Tuple[str, str]],
:param choices: Menu lines, len must be > 0
:type choices: list of tuples (tag, item) or
list of descriptions (tags will be enumerated)
:param default: default value to return (if one exists)
:param default: default value to return, if interaction is not possible
:param str cli_flag: option used to set this value with the CLI
:param bool force_interactive: True if it's safe to prompt the user
because it won't cause any workflow regressions
@ -110,7 +109,7 @@ def input_text(message: str, default: Optional[str] = None, cli_flag: Optional[s
"""Accept input from the user.
:param str message: message to display to the user
:param default: default value to return (if one exists)
:param default: default value to return, if interaction is not possible
:param str cli_flag: option used to set this value with the CLI
:param bool force_interactive: True if it's safe to prompt the user
because it won't cause any workflow regressions
@ -136,7 +135,7 @@ def yesno(message: str, yes_label: str = "Yes", no_label: str = "No",
:param str message: question for the user
:param str yes_label: Label of the "Yes" parameter
:param str no_label: Label of the "No" parameter
:param default: default value to return (if one exists)
:param default: default value to return, if interaction is not possible
:param str cli_flag: option used to set this value with the CLI
:param bool force_interactive: True if it's safe to prompt the user
because it won't cause any workflow regressions
@ -149,14 +148,14 @@ def yesno(message: str, yes_label: str = "Yes", no_label: str = "No",
cli_flag=cli_flag, force_interactive=force_interactive)
def checklist(message: str, tags: List[str], default: Optional[str] = None,
def checklist(message: str, tags: List[str], default: Optional[List[str]] = None,
cli_flag: Optional[str] = None,
force_interactive: bool = False) -> Tuple[str, List[str]]:
"""Display a checklist.
:param str message: Message to display to user
:param list tags: `str` tags to select, len(tags) > 0
:param default: default value to return (if one exists)
:param default: default value to return, if interaction is not possible
:param str cli_flag: option used to set this value with the CLI
:param bool force_interactive: True if it's safe to prompt the user
because it won't cause any workflow regressions
@ -172,11 +171,11 @@ def checklist(message: str, tags: List[str], default: Optional[str] = None,
def directory_select(message: str, default: Optional[str] = None, cli_flag: Optional[str] = None,
force_interactive: bool = False) -> Tuple[int, str]:
force_interactive: bool = False) -> Tuple[str, str]:
"""Display a directory selection screen.
:param str message: prompt to give the user
:param default: default value to return (if one exists)
:param default: default value to return, if interaction is not possible
:param str cli_flag: option used to set this value with the CLI
:param bool force_interactive: True if it's safe to prompt the user
because it won't cause any workflow regressions
@ -190,7 +189,7 @@ def directory_select(message: str, default: Optional[str] = None, cli_flag: Opti
force_interactive=force_interactive)
def assert_valid_call(prompt, default, cli_flag, force_interactive):
def assert_valid_call(prompt: str, default: str, cli_flag: str, force_interactive: bool) -> None:
"""Verify that provided arguments is a valid display call.
:param str prompt: prompt for the user
@ -215,10 +214,10 @@ class _DisplayUtilDeprecationModule:
Internal class delegating to a module, and displaying warnings when attributes
related to deprecated attributes in the certbot.display.util module.
"""
def __init__(self, module):
def __init__(self, module: ModuleType) -> None:
self.__dict__['_module'] = module
def __getattr__(self, attr):
def __getattr__(self, attr: str) -> Any:
if attr in ('FileDisplay', 'NoninteractiveDisplay', 'SIDE_FRAME', 'input_with_timeout',
'separate_list_input', 'summarize_domain_list', 'WIDTH', 'HELP', 'ESC'):
warnings.warn('{0} attribute in certbot.display.util module is deprecated '
@ -226,13 +225,13 @@ class _DisplayUtilDeprecationModule:
DeprecationWarning, stacklevel=2)
return getattr(self._module, attr)
def __setattr__(self, attr, value): # pragma: no cover
def __setattr__(self, attr: str, value: Any) -> None: # pragma: no cover
setattr(self._module, attr, value)
def __delattr__(self, attr): # pragma: no cover
def __delattr__(self, attr: str) -> None: # pragma: no cover
delattr(self._module, attr)
def __dir__(self): # pragma: no cover
def __dir__(self) -> List[str]: # pragma: no cover
return ['_module'] + dir(self._module)

View file

@ -1,16 +1,25 @@
"""Plugin common functions."""
from abc import ABCMeta
from abc import abstractmethod
import argparse
import logging
import re
import shutil
import tempfile
from typing import Any
from typing import Callable
from typing import Iterable
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
import pkg_resources
from certbot import achallenges
from certbot import configuration
from certbot import crypto_util
from certbot import interfaces
from certbot import errors
from certbot import reverter
from certbot._internal import constants
@ -23,12 +32,12 @@ from certbot.plugins.storage import PluginStorage
logger = logging.getLogger(__name__)
def option_namespace(name):
def option_namespace(name: str) -> str:
"""ArgumentParser options namespace (prefix of all options)."""
return name + "-"
def dest_namespace(name):
def dest_namespace(name: str) -> str:
"""ArgumentParser dest namespace (prefix of all destinations)."""
return name.replace("-", "_") + "_"
@ -43,14 +52,14 @@ hostname_regex = re.compile(
class Plugin(AbstractPlugin, metaclass=ABCMeta):
"""Generic plugin."""
def __init__(self, config, name):
def __init__(self, config: configuration.NamespaceConfig, name: str) -> None:
super().__init__(config, name)
self.config = config
self.name = name
@classmethod
@abstractmethod
def add_parser_arguments(cls, add):
def add_parser_arguments(cls, add: Callable[..., None]) -> None:
"""Add plugin arguments to the CLI argument parser.
:param callable add: Function that proxies calls to
@ -60,40 +69,40 @@ class Plugin(AbstractPlugin, metaclass=ABCMeta):
"""
@classmethod
def inject_parser_options(cls, parser, name):
def inject_parser_options(cls, parser: argparse.ArgumentParser, name: str) -> None:
"""Inject parser options.
See `~.certbot.interfaces.Plugin.inject_parser_options` for docs.
"""
# dummy function, doesn't check if dest.startswith(self.dest_namespace)
def add(arg_name_no_prefix, *args, **kwargs):
return parser.add_argument(
def add(arg_name_no_prefix: str, *args: Any, **kwargs: Any) -> None:
parser.add_argument(
"--{0}{1}".format(option_namespace(name), arg_name_no_prefix),
*args, **kwargs)
return cls.add_parser_arguments(add)
@property
def option_namespace(self):
def option_namespace(self) -> str:
"""ArgumentParser options namespace (prefix of all options)."""
return option_namespace(self.name)
def option_name(self, name):
def option_name(self, name: str) -> str:
"""Option name (include plugin namespace)."""
return self.option_namespace + name
@property
def dest_namespace(self):
def dest_namespace(self) -> str:
"""ArgumentParser dest namespace (prefix of all destinations)."""
return dest_namespace(self.name)
def dest(self, var):
def dest(self, var: str) -> str:
"""Find a destination for given variable ``var``."""
# this should do exactly the same what ArgumentParser(arg),
# does to "arg" to compute "dest"
return self.dest_namespace + var.replace("-", "_")
def conf(self, var):
def conf(self, var: str) -> Any:
"""Find a configuration value for variable ``var``."""
return getattr(self.config, self.dest(var))
@ -130,12 +139,13 @@ class Installer(AbstractInstaller, Plugin, metaclass=ABCMeta):
Installer plugins do not have to inherit from this class.
"""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.storage = PluginStorage(self.config, self.name)
self.reverter = reverter.Reverter(self.config)
def add_to_checkpoint(self, save_files, save_notes, temporary=False):
def add_to_checkpoint(self, save_files: Set[str], save_notes: str,
temporary: bool = False) -> None:
"""Add files to a checkpoint.
:param set save_files: set of filepaths to save
@ -157,7 +167,7 @@ class Installer(AbstractInstaller, Plugin, metaclass=ABCMeta):
except errors.ReverterError as err:
raise errors.PluginError(str(err))
def finalize_checkpoint(self, title):
def finalize_checkpoint(self, title: str) -> None:
"""Timestamp and save changes made through the reverter.
:param str title: Title describing checkpoint
@ -170,7 +180,7 @@ class Installer(AbstractInstaller, Plugin, metaclass=ABCMeta):
except errors.ReverterError as err:
raise errors.PluginError(str(err))
def recovery_routine(self):
def recovery_routine(self) -> None:
"""Revert all previously modified files.
Reverts all modified files that have not been saved as a checkpoint
@ -183,7 +193,7 @@ class Installer(AbstractInstaller, Plugin, metaclass=ABCMeta):
except errors.ReverterError as err:
raise errors.PluginError(str(err))
def revert_temporary_config(self):
def revert_temporary_config(self) -> None:
"""Rollback temporary checkpoint.
:raises .errors.PluginError: when unable to revert config
@ -194,7 +204,7 @@ class Installer(AbstractInstaller, Plugin, metaclass=ABCMeta):
except errors.ReverterError as err:
raise errors.PluginError(str(err))
def rollback_checkpoints(self, rollback=1):
def rollback_checkpoints(self, rollback: int = 1) -> None:
"""Rollback saved checkpoints.
:param int rollback: Number of checkpoints to revert
@ -209,24 +219,31 @@ class Installer(AbstractInstaller, Plugin, metaclass=ABCMeta):
raise errors.PluginError(str(err))
@property
def ssl_dhparams(self):
def ssl_dhparams(self) -> str:
"""Full absolute path to ssl_dhparams file."""
return os.path.join(self.config.config_dir, constants.SSL_DHPARAMS_DEST)
@property
def updated_ssl_dhparams_digest(self):
def updated_ssl_dhparams_digest(self) -> str:
"""Full absolute path to digest of updated ssl_dhparams file."""
return os.path.join(self.config.config_dir, constants.UPDATED_SSL_DHPARAMS_DIGEST)
def install_ssl_dhparams(self):
def install_ssl_dhparams(self) -> None:
"""Copy Certbot's ssl_dhparams file into the system's config dir if required."""
return install_version_controlled_file(
install_version_controlled_file(
self.ssl_dhparams,
self.updated_ssl_dhparams_digest,
constants.SSL_DHPARAMS_SRC,
constants.ALL_SSL_DHPARAMS_HASHES)
class Configurator(Installer, interfaces.Authenticator, metaclass=ABCMeta):
"""
A plugin that extends certbot.plugins.common.Installer
and implements certbot.interfaces.Authenticator
"""
class Addr:
r"""Represents an virtual host address.
@ -234,12 +251,12 @@ class Addr:
:param str port: port number or \*, or ""
"""
def __init__(self, tup, ipv6=False):
def __init__(self, tup: Tuple[str, str], ipv6: bool = False):
self.tup = tup
self.ipv6 = ipv6
@classmethod
def fromstring(cls, str_addr):
def fromstring(cls, str_addr: str) -> 'Addr':
"""Initialize Addr from string."""
if str_addr.startswith('['):
# ipv6 addresses starts with [
@ -253,19 +270,19 @@ class Addr:
tup = str_addr.partition(':')
return cls((tup[0], tup[2]))
def __str__(self):
def __str__(self) -> str:
if self.tup[1]:
return "%s:%s" % self.tup
return self.tup[0]
def normalized_tuple(self):
def normalized_tuple(self) -> Tuple[str, str]:
"""Normalized representation of addr/port tuple
"""
if self.ipv6:
return (self.get_ipv6_exploded(), self.tup[1])
return self.get_ipv6_exploded(), self.tup[1]
return self.tup
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if isinstance(other, self.__class__):
# compare normalized to take different
# styles of representation into account
@ -273,34 +290,34 @@ class Addr:
return False
def __hash__(self):
def __hash__(self) -> int:
return hash(self.tup)
def get_addr(self):
def get_addr(self) -> str:
"""Return addr part of Addr object."""
return self.tup[0]
def get_port(self):
def get_port(self) -> str:
"""Return port."""
return self.tup[1]
def get_addr_obj(self, port):
def get_addr_obj(self, port: str) -> 'Addr':
"""Return new address object with same addr and new port."""
return self.__class__((self.tup[0], port), self.ipv6)
def _normalize_ipv6(self, addr):
def _normalize_ipv6(self, addr: str) -> List[str]:
"""Return IPv6 address in normalized form, helper function"""
addr = addr.lstrip("[")
addr = addr.rstrip("]")
return self._explode_ipv6(addr)
def get_ipv6_exploded(self):
def get_ipv6_exploded(self) -> str:
"""Return IPv6 in normalized form"""
if self.ipv6:
return ":".join(self._normalize_ipv6(self.tup[0]))
return ""
def _explode_ipv6(self, addr):
def _explode_ipv6(self, addr: str) -> List[str]:
"""Explode IPv6 address for comparison"""
result = ['0', '0', '0', '0', '0', '0', '0', '0']
addr_list = addr.split(":")
@ -337,12 +354,13 @@ class ChallengePerformer:
"""
def __init__(self, configurator):
def __init__(self, configurator: Configurator):
self.configurator = configurator
self.achalls: List[achallenges.KeyAuthorizationAnnotatedChallenge] = []
self.indices: List[int] = []
def add_chall(self, achall, idx=None):
def add_chall(self, achall: achallenges.KeyAuthorizationAnnotatedChallenge,
idx: Optional[int] = None) -> None:
"""Store challenge to be performed when perform() is called.
:param .KeyAuthorizationAnnotatedChallenge achall: Annotated
@ -354,7 +372,7 @@ class ChallengePerformer:
if idx is not None:
self.indices.append(idx)
def perform(self):
def perform(self) -> List[achallenges.KeyAuthorizationAnnotatedChallenge]:
"""Perform all added challenges.
:returns: challenge responses
@ -365,7 +383,8 @@ class ChallengePerformer:
raise NotImplementedError()
def install_version_controlled_file(dest_path, digest_path, src_path, all_hashes):
def install_version_controlled_file(dest_path: str, digest_path: str, src_path: str,
all_hashes: Iterable[str]) -> None:
"""Copy a file into an active location (likely the system's config dir) if required.
:param str dest_path: destination path for version controlled file
@ -375,11 +394,11 @@ def install_version_controlled_file(dest_path, digest_path, src_path, all_hashes
"""
current_hash = crypto_util.sha256sum(src_path)
def _write_current_hash():
with open(digest_path, "w") as f:
f.write(current_hash)
def _write_current_hash() -> None:
with open(digest_path, "w") as file_h:
file_h.write(current_hash)
def _install_current_file():
def _install_current_file() -> None:
shutil.copyfile(src_path, dest_path)
_write_current_hash()
@ -415,9 +434,9 @@ def install_version_controlled_file(dest_path, digest_path, src_path, all_hashes
# "pragma: no cover") TODO: this might quickly lead to dead code (also
# c.f. #383)
def dir_setup(test_dir, pkg): # pragma: no cover
def dir_setup(test_dir: str, pkg: str) -> Tuple[str, str, str]: # pragma: no cover
"""Setup the directories necessary for the configurator."""
def expanded_tempdir(prefix):
def expanded_tempdir(prefix: str) -> str:
"""Return the real path of a temp directory with the specified prefix
Some plugins rely on real paths of symlinks for working correctly. For

View file

@ -1,12 +1,19 @@
"""Common code for DNS Authenticator Plugins."""
import abc
import logging
from time import sleep
from typing import Callable
from typing import Iterable
from typing import List
from typing import Mapping
from typing import Optional
from typing import Type
import configobj
from acme import challenges
from certbot import achallenges
from certbot import configuration
from certbot import errors
from certbot import interfaces
from certbot.compat import filesystem
@ -21,20 +28,21 @@ logger = logging.getLogger(__name__)
class DNSAuthenticator(common.Plugin, interfaces.Authenticator, metaclass=abc.ABCMeta):
"""Base class for DNS Authenticators"""
def __init__(self, config, name):
def __init__(self, config: configuration.NamespaceConfig, name: str) -> None:
super().__init__(config, name)
self._attempt_cleanup = False
@classmethod
def add_parser_arguments(cls, add, default_propagation_seconds=10): # pylint: disable=arguments-differ
def add_parser_arguments(cls, add: Callable[..., None], # pylint: disable=arguments-differ
default_propagation_seconds: int = 10) -> None:
add('propagation-seconds',
default=default_propagation_seconds,
type=int,
help='The number of seconds to wait for DNS to propagate before asking the ACME server '
'to verify the DNS record.')
def auth_hint(self, failed_achalls):
def auth_hint(self, failed_achalls: List[achallenges.AnnotatedChallenge]) -> str:
"""See certbot.plugins.common.Plugin.auth_hint."""
delay = self.conf('propagation-seconds')
return (
@ -44,16 +52,17 @@ class DNSAuthenticator(common.Plugin, interfaces.Authenticator, metaclass=abc.AB
.format(name=self.name, secs=delay, suffix='s' if delay != 1 else '')
)
def get_chall_pref(self, unused_domain): # pylint: disable=missing-function-docstring
def get_chall_pref(self, unused_domain: str) -> Iterable[Type[challenges.Challenge]]: # pylint: disable=missing-function-docstring
return [challenges.DNS01]
def prepare(self): # pylint: disable=missing-function-docstring
def prepare(self) -> None: # pylint: disable=missing-function-docstring
pass
def more_info(self) -> str: # pylint: disable=missing-function-docstring
raise NotImplementedError()
def perform(self, achalls): # pylint: disable=missing-function-docstring
def perform(self, achalls: List[achallenges.AnnotatedChallenge]
) -> List[challenges.ChallengeResponse]: # pylint: disable=missing-function-docstring
self._setup_credentials()
self._attempt_cleanup = True
@ -76,7 +85,7 @@ class DNSAuthenticator(common.Plugin, interfaces.Authenticator, metaclass=abc.AB
return responses
def cleanup(self, achalls): # pylint: disable=missing-function-docstring
def cleanup(self, achalls: List[achallenges.AnnotatedChallenge]) -> None: # pylint: disable=missing-function-docstring
if self._attempt_cleanup:
for achall in achalls:
domain = achall.domain
@ -86,14 +95,15 @@ class DNSAuthenticator(common.Plugin, interfaces.Authenticator, metaclass=abc.AB
self._cleanup(domain, validation_domain_name, validation)
@abc.abstractmethod
def _setup_credentials(self): # pragma: no cover
def _setup_credentials(self) -> None: # pragma: no cover
"""
Establish credentials, prompting if necessary.
"""
raise NotImplementedError()
@abc.abstractmethod
def _perform(self, domain, validation_name, validation): # pragma: no cover
def _perform(self, domain: str, validation_name: str,
validation: str) -> None: # pragma: no cover
"""
Performs a dns-01 challenge by creating a DNS TXT record.
@ -105,7 +115,8 @@ class DNSAuthenticator(common.Plugin, interfaces.Authenticator, metaclass=abc.AB
raise NotImplementedError()
@abc.abstractmethod
def _cleanup(self, domain, validation_name, validation): # pragma: no cover
def _cleanup(self, domain: str, validation_name: str,
validation: str) -> None: # pragma: no cover
"""
Deletes the DNS TXT record which would have been created by `_perform_achall`.
@ -117,7 +128,7 @@ class DNSAuthenticator(common.Plugin, interfaces.Authenticator, metaclass=abc.AB
"""
raise NotImplementedError()
def _configure(self, key, label):
def _configure(self, key: str, label: str) -> None:
"""
Ensure that a configuration value is available.
@ -133,7 +144,8 @@ class DNSAuthenticator(common.Plugin, interfaces.Authenticator, metaclass=abc.AB
setattr(self.config, self.dest(key), new_value)
def _configure_file(self, key, label, validator=None):
def _configure_file(self, key: str, label: str,
validator: Optional[Callable[[str], None]] = None) -> None:
"""
Ensure that a configuration value is available for a path.
@ -149,8 +161,10 @@ class DNSAuthenticator(common.Plugin, interfaces.Authenticator, metaclass=abc.AB
setattr(self.config, self.dest(key), os.path.abspath(os.path.expanduser(new_value)))
def _configure_credentials(self, key, label, required_variables=None,
validator=None) -> 'CredentialsConfiguration':
def _configure_credentials(
self, key: str, label: str, required_variables: Optional[Mapping[str, str]] = None,
validator: Optional[Callable[['CredentialsConfiguration'], None]] = None
) -> 'CredentialsConfiguration':
"""
As `_configure_file`, but for a credential configuration file.
@ -167,14 +181,14 @@ class DNSAuthenticator(common.Plugin, interfaces.Authenticator, metaclass=abc.AB
indicate any issue.
"""
def __validator(filename): # pylint: disable=unused-private-member
configuration = CredentialsConfiguration(filename, self.dest)
def __validator(filename: str) -> None: # pylint: disable=unused-private-member
applied_configuration = CredentialsConfiguration(filename, self.dest)
if required_variables:
configuration.require(required_variables)
applied_configuration.require(required_variables)
if validator:
validator(configuration)
validator(applied_configuration)
self._configure_file(key, label, __validator)
@ -188,7 +202,7 @@ class DNSAuthenticator(common.Plugin, interfaces.Authenticator, metaclass=abc.AB
return credentials_configuration
@staticmethod
def _prompt_for_data(label):
def _prompt_for_data(label: str) -> str:
"""
Prompt the user for a piece of information.
@ -197,7 +211,7 @@ class DNSAuthenticator(common.Plugin, interfaces.Authenticator, metaclass=abc.AB
:rtype: str
"""
def __validator(i): # pylint: disable=unused-private-member
def __validator(i: str) -> None: # pylint: disable=unused-private-member
if not i:
raise errors.PluginError('Please enter your {0}.'.format(label))
@ -211,7 +225,7 @@ class DNSAuthenticator(common.Plugin, interfaces.Authenticator, metaclass=abc.AB
raise errors.PluginError('{0} required to proceed.'.format(label))
@staticmethod
def _prompt_for_file(label, validator=None):
def _prompt_for_file(label: str, validator: Optional[Callable[[str], None]] = None) -> str:
"""
Prompt the user for a path.
@ -223,7 +237,7 @@ class DNSAuthenticator(common.Plugin, interfaces.Authenticator, metaclass=abc.AB
:rtype: str
"""
def __validator(filename): # pylint: disable=unused-private-member
def __validator(filename: str) -> None: # pylint: disable=unused-private-member
if not filename:
raise errors.PluginError('Please enter a valid path to your {0}.'.format(label))
@ -247,7 +261,7 @@ class DNSAuthenticator(common.Plugin, interfaces.Authenticator, metaclass=abc.AB
class CredentialsConfiguration:
"""Represents a user-supplied filed which stores API credentials."""
def __init__(self, filename, mapper=lambda x: x):
def __init__(self, filename: str, mapper: Callable[[str], str] = lambda x: x) -> None:
"""
:param str filename: A path to the configuration file.
:param callable mapper: A transformation to apply to configuration key names
@ -263,7 +277,7 @@ class CredentialsConfiguration:
self.mapper = mapper
def require(self, required_variables):
def require(self, required_variables: Mapping[str, str]) -> None:
"""Ensures that the supplied set of variables are all present in the file.
:param dict required_variables: Map of variable which must be present to error to display.
@ -288,7 +302,7 @@ class CredentialsConfiguration:
)
)
def conf(self, var):
def conf(self, var: str) -> str:
"""Find a configuration value for variable `var`, as transformed by `mapper`.
:param str var: The variable to get.
@ -298,14 +312,14 @@ class CredentialsConfiguration:
return self._get(var)
def _has(self, var):
def _has(self, var: str) -> bool:
return self.mapper(var) in self.confobj
def _get(self, var):
def _get(self, var: str) -> str:
return self.confobj.get(self.mapper(var))
def validate_file(filename):
def validate_file(filename: str) -> None:
"""Ensure that the specified file exists."""
if not os.path.exists(filename):
@ -315,7 +329,7 @@ def validate_file(filename):
raise errors.PluginError('Path is a directory: {0}'.format(filename))
def validate_file_permissions(filename):
def validate_file_permissions(filename: str) -> None:
"""Ensure that the specified file exists and warn about unsafe permissions."""
validate_file(filename)
@ -324,7 +338,7 @@ def validate_file_permissions(filename):
logger.warning('Unsafe permissions on credentials configuration file: %s', filename)
def base_domain_name_guesses(domain):
def base_domain_name_guesses(domain: str) -> List[str]:
"""Return a list of progressively less-specific domain names.
One of these will probably be the domain name known to the DNS provider.

View file

@ -2,6 +2,8 @@
import logging
from typing import Any
from typing import Dict
from typing import Mapping
from typing import Optional
from typing import Union
from requests.exceptions import HTTPError
@ -30,10 +32,10 @@ class LexiconClient:
Encapsulates all communication with a DNS provider via Lexicon.
"""
def __init__(self):
def __init__(self) -> None:
self.provider: Provider
def add_txt_record(self, domain, record_name, record_content):
def add_txt_record(self, domain: str, record_name: str, record_content: str) -> None:
"""
Add a TXT record using the supplied information.
@ -50,7 +52,7 @@ class LexiconClient:
logger.debug('Encountered error adding TXT record: %s', e, exc_info=True)
raise errors.PluginError('Error adding TXT record: {0}'.format(e))
def del_txt_record(self, domain, record_name, record_content):
def del_txt_record(self, domain: str, record_name: str, record_content: str) -> None:
"""
Delete a TXT record using the supplied information.
@ -71,7 +73,7 @@ class LexiconClient:
except RequestException as e:
logger.debug('Encountered error deleting TXT record: %s', e, exc_info=True)
def _find_domain_id(self, domain):
def _find_domain_id(self, domain: str) -> None:
"""
Find the domain_id for a given domain.
@ -94,24 +96,24 @@ class LexiconClient:
return # If `authenticate` doesn't throw an exception, we've found the right name
except HTTPError as e:
result = self._handle_http_error(e, domain_name)
result1 = self._handle_http_error(e, domain_name)
if result:
raise result
if result1:
raise result1
except Exception as e: # pylint: disable=broad-except
result = self._handle_general_error(e, domain_name)
result2 = self._handle_general_error(e, domain_name)
if result:
raise result # pylint: disable=raising-bad-type
if result2:
raise result2 # pylint: disable=raising-bad-type
raise errors.PluginError('Unable to determine zone identifier for {0} using zone names: {1}'
.format(domain, domain_name_guesses))
def _handle_http_error(self, e, domain_name):
def _handle_http_error(self, e: HTTPError, domain_name: str) -> errors.PluginError:
return errors.PluginError('Error determining zone identifier for {0}: {1}.'
.format(domain_name, e))
def _handle_general_error(self, e, domain_name):
def _handle_general_error(self, e: Exception, domain_name: str) -> Optional[errors.PluginError]:
if not str(e).startswith('No domain found'):
return errors.PluginError('Unexpected error determining zone identifier for {0}: {1}'
.format(domain_name, e))
@ -119,8 +121,8 @@ class LexiconClient:
def build_lexicon_config(lexicon_provider_name: str,
lexicon_options: Dict, provider_options: Dict
) -> Union[ConfigResolver, Dict]:
lexicon_options: Mapping[str, Any], provider_options: Mapping[str, Any]
) -> Union[ConfigResolver, Dict[str, Any]]:
"""
Convenient function to build a Lexicon 2.x/3.x config object.
:param str lexicon_provider_name: the name of the lexicon provider to use
@ -129,14 +131,14 @@ def build_lexicon_config(lexicon_provider_name: str,
:return: configuration to apply to the provider
:rtype: ConfigurationResolver or dict
"""
config: Dict[str, Any] = {'provider_name': lexicon_provider_name}
config: Union[ConfigResolver, Dict[str, Any]] = {'provider_name': lexicon_provider_name}
config.update(lexicon_options)
if not ConfigResolver:
# Lexicon 2.x
config.update(provider_options)
else:
# Lexicon 3.x
provider_config = {}
provider_config: Dict[str, Any] = {}
provider_config.update(provider_options)
config[lexicon_provider_name] = provider_config
config = ConfigResolver().with_dict(config).with_env()

View file

@ -1,5 +1,7 @@
"""Base test class for DNS authenticators."""
import typing
from typing import Any
from typing import Mapping
from typing import TYPE_CHECKING
import configobj
import josepy as jose
@ -11,13 +13,12 @@ from certbot.plugins.dns_common import DNSAuthenticator
from certbot.tests import acme_util
from certbot.tests import util as test_util
if typing.TYPE_CHECKING:
if TYPE_CHECKING:
from typing_extensions import Protocol
else:
Protocol = object
try:
import mock
except ImportError: # pragma: no cover
@ -32,14 +33,14 @@ class _AuthenticatorCallableTestCase(Protocol):
"""Protocol describing a TestCase able to call a real DNSAuthenticator instance."""
auth: DNSAuthenticator
def assertTrue(self, *unused_args) -> None:
def assertTrue(self, *unused_args: Any) -> None:
"""
See
https://docs.python.org/3/library/unittest.html#unittest.TestCase.assertTrue
"""
...
def assertEqual(self, *unused_args) -> None:
def assertEqual(self, *unused_args: Any) -> None:
"""
See
https://docs.python.org/3/library/unittest.html#unittest.TestCase.assertEqual
@ -59,20 +60,20 @@ class BaseAuthenticatorTest:
achall = achallenges.KeyAuthorizationAnnotatedChallenge(
challb=acme_util.DNS01, domain=DOMAIN, account_key=KEY)
def test_more_info(self: _AuthenticatorCallableTestCase):
def test_more_info(self: _AuthenticatorCallableTestCase) -> None:
self.assertTrue(isinstance(self.auth.more_info(), str)) # pylint: disable=no-member
def test_get_chall_pref(self: _AuthenticatorCallableTestCase):
self.assertEqual(self.auth.get_chall_pref(None), [challenges.DNS01]) # pylint: disable=no-member
def test_get_chall_pref(self: _AuthenticatorCallableTestCase) -> None:
self.assertEqual(self.auth.get_chall_pref("example.org"), [challenges.DNS01]) # pylint: disable=no-member
def test_parser_arguments(self: _AuthenticatorCallableTestCase):
def test_parser_arguments(self: _AuthenticatorCallableTestCase) -> None:
m = mock.MagicMock()
self.auth.add_parser_arguments(m) # pylint: disable=no-member
m.assert_any_call('propagation-seconds', type=int, default=mock.ANY, help=mock.ANY)
def write(values, path):
def write(values: Mapping[str, Any], path: str) -> None:
"""Write the specified values to a config file.
:param dict values: A map of values to write.

View file

@ -1,5 +1,6 @@
"""Base test class for DNS authenticators built on Lexicon."""
import typing
from typing import Any
from typing import TYPE_CHECKING
from unittest.mock import MagicMock
import josepy as jose
@ -17,14 +18,11 @@ try:
import mock
except ImportError: # pragma: no cover
from unittest import mock # type: ignore
if typing.TYPE_CHECKING:
if TYPE_CHECKING:
from typing_extensions import Protocol
else:
Protocol = object
DOMAIN = 'example.com'
KEY = jose.JWKRSA.load(test_util.load_vector("rsa512_key.pem"))
@ -54,7 +52,7 @@ class _LexiconAwareTestCase(Protocol):
LOGIN_ERROR: Exception
UNKNOWN_LOGIN_ERROR: Exception
def assertRaises(self, *unused_args) -> None:
def assertRaises(self, *unused_args: Any) -> None:
"""
See
https://docs.python.org/3/library/unittest.html#unittest.TestCase.assertRaises
@ -68,13 +66,14 @@ class _LexiconAwareTestCase(Protocol):
class BaseLexiconAuthenticatorTest(dns_test_common.BaseAuthenticatorTest):
@test_util.patch_display_util()
def test_perform(self: _AuthenticatorCallableLexiconTestCase, unused_mock_get_utility):
def test_perform(self: _AuthenticatorCallableLexiconTestCase,
unused_mock_get_utility: Any) -> None:
self.auth.perform([self.achall])
expected = [mock.call.add_txt_record(DOMAIN, '_acme-challenge.'+DOMAIN, mock.ANY)]
self.assertEqual(expected, self.mock_client.mock_calls)
def test_cleanup(self: _AuthenticatorCallableLexiconTestCase):
def test_cleanup(self: _AuthenticatorCallableLexiconTestCase) -> None:
self.auth._attempt_cleanup = True # _attempt_cleanup | pylint: disable=protected-access
self.auth.cleanup([self.achall])
@ -92,14 +91,14 @@ class BaseLexiconClientTest:
record_name = record_prefix + "." + DOMAIN
record_content = "bar"
def test_add_txt_record(self: _LexiconAwareTestCase):
def test_add_txt_record(self: _LexiconAwareTestCase) -> None:
self.client.add_txt_record(DOMAIN, self.record_name, self.record_content)
self.provider_mock.create_record.assert_called_with(rtype='TXT',
name=self.record_name,
content=self.record_content)
def test_add_txt_record_try_twice_to_find_domain(self: _LexiconAwareTestCase):
def test_add_txt_record_try_twice_to_find_domain(self: _LexiconAwareTestCase) -> None:
self.provider_mock.authenticate.side_effect = [self.DOMAIN_NOT_FOUND, '']
self.client.add_txt_record(DOMAIN, self.record_name, self.record_content)
@ -108,7 +107,7 @@ class BaseLexiconClientTest:
name=self.record_name,
content=self.record_content)
def test_add_txt_record_fail_to_find_domain(self: _LexiconAwareTestCase):
def test_add_txt_record_fail_to_find_domain(self: _LexiconAwareTestCase) -> None:
self.provider_mock.authenticate.side_effect = [self.DOMAIN_NOT_FOUND,
self.DOMAIN_NOT_FOUND,
self.DOMAIN_NOT_FOUND,]
@ -117,64 +116,66 @@ class BaseLexiconClientTest:
self.client.add_txt_record,
DOMAIN, self.record_name, self.record_content)
def test_add_txt_record_fail_to_authenticate(self: _LexiconAwareTestCase):
def test_add_txt_record_fail_to_authenticate(self: _LexiconAwareTestCase) -> None:
self.provider_mock.authenticate.side_effect = self.LOGIN_ERROR
self.assertRaises(errors.PluginError,
self.client.add_txt_record,
DOMAIN, self.record_name, self.record_content)
def test_add_txt_record_fail_to_authenticate_with_unknown_error(self: _LexiconAwareTestCase):
def test_add_txt_record_fail_to_authenticate_with_unknown_error(
self: _LexiconAwareTestCase) -> None:
self.provider_mock.authenticate.side_effect = self.UNKNOWN_LOGIN_ERROR
self.assertRaises(errors.PluginError,
self.client.add_txt_record,
DOMAIN, self.record_name, self.record_content)
def test_add_txt_record_error_finding_domain(self: _LexiconAwareTestCase):
def test_add_txt_record_error_finding_domain(self: _LexiconAwareTestCase) -> None:
self.provider_mock.authenticate.side_effect = self.GENERIC_ERROR
self.assertRaises(errors.PluginError,
self.client.add_txt_record,
DOMAIN, self.record_name, self.record_content)
def test_add_txt_record_error_adding_record(self: _LexiconAwareTestCase):
def test_add_txt_record_error_adding_record(self: _LexiconAwareTestCase) -> None:
self.provider_mock.create_record.side_effect = self.GENERIC_ERROR
self.assertRaises(errors.PluginError,
self.client.add_txt_record,
DOMAIN, self.record_name, self.record_content)
def test_del_txt_record(self: _LexiconAwareTestCase):
def test_del_txt_record(self: _LexiconAwareTestCase) -> None:
self.client.del_txt_record(DOMAIN, self.record_name, self.record_content)
self.provider_mock.delete_record.assert_called_with(rtype='TXT',
name=self.record_name,
content=self.record_content)
def test_del_txt_record_fail_to_find_domain(self: _LexiconAwareTestCase):
def test_del_txt_record_fail_to_find_domain(self: _LexiconAwareTestCase) -> None:
self.provider_mock.authenticate.side_effect = [self.DOMAIN_NOT_FOUND,
self.DOMAIN_NOT_FOUND,
self.DOMAIN_NOT_FOUND, ]
self.client.del_txt_record(DOMAIN, self.record_name, self.record_content)
def test_del_txt_record_fail_to_authenticate(self: _LexiconAwareTestCase):
def test_del_txt_record_fail_to_authenticate(self: _LexiconAwareTestCase) -> None:
self.provider_mock.authenticate.side_effect = self.LOGIN_ERROR
self.client.del_txt_record(DOMAIN, self.record_name, self.record_content)
def test_del_txt_record_fail_to_authenticate_with_unknown_error(self: _LexiconAwareTestCase):
def test_del_txt_record_fail_to_authenticate_with_unknown_error(
self: _LexiconAwareTestCase) -> None:
self.provider_mock.authenticate.side_effect = self.UNKNOWN_LOGIN_ERROR
self.client.del_txt_record(DOMAIN, self.record_name, self.record_content)
def test_del_txt_record_error_finding_domain(self: _LexiconAwareTestCase):
def test_del_txt_record_error_finding_domain(self: _LexiconAwareTestCase) -> None:
self.provider_mock.authenticate.side_effect = self.GENERIC_ERROR
self.client.del_txt_record(DOMAIN, self.record_name, self.record_content)
def test_del_txt_record_error_deleting_record(self: _LexiconAwareTestCase):
def test_del_txt_record_error_deleting_record(self: _LexiconAwareTestCase) -> None:
self.provider_mock.delete_record.side_effect = self.GENERIC_ERROR
self.client.del_txt_record(DOMAIN, self.record_name, self.record_content)

View file

@ -1,9 +1,15 @@
"""New interface style Certbot enhancements"""
import abc
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generator
from typing import Iterable
from typing import List
from typing import Optional
from certbot import configuration
from certbot import interfaces
from certbot._internal import constants
ENHANCEMENTS = ["redirect", "ensure-http-header", "ocsp-stapling"]
@ -17,7 +23,9 @@ List of expected options parameters:
"""
def enabled_enhancements(config):
def enabled_enhancements(
config: configuration.NamespaceConfig) -> Generator[Dict[str, Any], None, None]:
"""
Generator to yield the enabled new style enhancements.
@ -28,7 +36,8 @@ def enabled_enhancements(config):
if getattr(config, enh["cli_dest"]):
yield enh
def are_requested(config):
def are_requested(config: configuration.NamespaceConfig) -> bool:
"""
Checks if one or more of the requested enhancements are those of the new
enhancement interfaces.
@ -38,7 +47,9 @@ def are_requested(config):
"""
return any(enabled_enhancements(config))
def are_supported(config, installer):
def are_supported(config: configuration.NamespaceConfig,
installer: Optional[interfaces.Installer]) -> bool:
"""
Checks that all of the requested enhancements are supported by the
installer.
@ -57,7 +68,10 @@ def are_supported(config, installer):
return False
return True
def enable(lineage, domains, installer, config):
def enable(lineage: Optional[interfaces.RenewableCert], domains: Iterable[str],
installer: Optional[interfaces.Installer],
config: configuration.NamespaceConfig) -> None:
"""
Run enable method for each requested enhancement that is supported.
@ -73,10 +87,12 @@ def enable(lineage, domains, installer, config):
:param config: Configuration.
:type config: certbot.configuration.NamespaceConfig
"""
for enh in enabled_enhancements(config):
getattr(installer, enh["enable_function"])(lineage, domains)
if installer:
for enh in enabled_enhancements(config):
getattr(installer, enh["enable_function"])(lineage, domains)
def populate_cli(add):
def populate_cli(add: Callable[..., None]) -> None:
"""
Populates the command line flags for certbot._internal.cli.HelpfulParser
@ -116,7 +132,7 @@ class AutoHSTSEnhancement(object, metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def update_autohsts(self, lineage, *args, **kwargs):
def update_autohsts(self, lineage: interfaces.RenewableCert, *args: Any, **kwargs: Any) -> None:
"""
Gets called for each lineage every time Certbot is run with 'renew' verb.
Implementation of this method should increase the max-age value.
@ -130,7 +146,7 @@ class AutoHSTSEnhancement(object, metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def deploy_autohsts(self, lineage, *args, **kwargs):
def deploy_autohsts(self, lineage: interfaces.RenewableCert, *args: Any, **kwargs: Any) -> None:
"""
Gets called for a lineage when its certificate is successfully renewed.
Long max-age value should be set in implementation of this method.
@ -140,7 +156,8 @@ class AutoHSTSEnhancement(object, metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def enable_autohsts(self, lineage, domains, *args, **kwargs):
def enable_autohsts(self, lineage: Optional[interfaces.RenewableCert], domains: Iterable[str],
*args: Any, **kwargs: Any) -> None:
"""
Enables the AutoHSTS enhancement, installing
Strict-Transport-Security header with a low initial value to be increased
@ -153,6 +170,7 @@ class AutoHSTSEnhancement(object, metaclass=abc.ABCMeta):
:type domains: `list` of `str`
"""
# This is used to configure internal new style enhancements in Certbot. These
# enhancement interfaces need to be defined in this file. Please do not modify
# this list from plugin code.

View file

@ -4,6 +4,7 @@ import logging
from typing import Any
from typing import Dict
from certbot import configuration
from certbot import errors
from certbot.compat import filesystem
from certbot.compat import os
@ -14,7 +15,7 @@ logger = logging.getLogger(__name__)
class PluginStorage:
"""Class implementing storage functionality for plugins"""
def __init__(self, config, classkey):
def __init__(self, config: configuration.NamespaceConfig, classkey: str) -> None:
"""Initializes PluginStorage object storing required configuration
options.
@ -29,7 +30,7 @@ class PluginStorage:
self._data: Dict
self._storagepath: str
def _initialize_storage(self):
def _initialize_storage(self) -> None:
"""Initializes PluginStorage data and reads current state from the disk
if the storage json exists."""
@ -37,7 +38,7 @@ class PluginStorage:
self._load()
self._initialized = True
def _load(self):
def _load(self) -> None:
"""Reads PluginStorage content from the disk to a dict structure
:raises .errors.PluginStorageError: when unable to open or read the file
@ -67,7 +68,7 @@ class PluginStorage:
raise errors.PluginStorageError(errmsg)
self._data = data
def save(self):
def save(self) -> None:
"""Saves PluginStorage content to disk
:raises .errors.PluginStorageError: when unable to serialize the data
@ -97,7 +98,7 @@ class PluginStorage:
logger.error(errmsg)
raise errors.PluginStorageError(errmsg)
def put(self, key, value):
def put(self, key: str, value: Any) -> None:
"""Put configuration value to PluginStorage
:param str key: Key to store the value to
@ -110,7 +111,7 @@ class PluginStorage:
self._data[self._classkey] = {}
self._data[self._classkey][key] = value
def fetch(self, key):
def fetch(self, key: str) -> Any:
"""Get configuration value from PluginStorage
:param str key: Key to get value from the storage

View file

@ -1,5 +1,6 @@
"""Plugin utilities."""
import logging
from typing import List
from certbot import util
from certbot.compat import os
@ -8,7 +9,7 @@ from certbot.compat.misc import STANDARD_BINARY_DIRS
logger = logging.getLogger(__name__)
def get_prefixes(path):
def get_prefixes(path: str) -> List[str]:
"""Retrieves all possible path prefixes of a path, in descending order
of length. For instance:
@ -21,7 +22,7 @@ def get_prefixes(path):
:rtype: `list` of `str`
"""
prefix = os.path.normpath(path)
prefixes = []
prefixes: List[str] = []
while prefix:
prefixes.append(prefix)
prefix, _ = os.path.split(prefix)
@ -31,7 +32,7 @@ def get_prefixes(path):
return prefixes
def path_surgery(cmd):
def path_surgery(cmd: str) -> bool:
"""Attempt to perform PATH surgery to find cmd
Mitigates https://github.com/certbot/certbot/issues/1833

View file

@ -123,7 +123,7 @@ def run_script(params: List[str], log: Callable[[str], None]=logger.error) -> Tu
return proc.stdout, proc.stderr
def exe_exists(exe: Optional[str]) -> bool:
def exe_exists(exe: str) -> bool:
"""Determine whether path/name refers to an executable.
:param str exe: Executable path or name
@ -132,9 +132,6 @@ def exe_exists(exe: Optional[str]) -> bool:
:rtype: bool
"""
if exe is None:
return False
path, _ = os.path.split(exe)
if path:
return filesystem.is_executable(exe)

View file

@ -474,7 +474,7 @@ class ChooseValuesTest(unittest.TestCase):
result = self._call(items, None)
self.assertEqual(result, [items[2]])
self.assertIs(mock_util().checklist.called, True)
self.assertIsNone(mock_util().checklist.call_args[0][0])
self.assertEqual(mock_util().checklist.call_args[0][0], "")
@test_util.patch_display_util()
def test_choose_names_success_question(self, mock_util):