mirror of
https://github.com/certbot/certbot.git
synced 2026-06-08 08:12:15 -04:00
Incorporated Kuba's feedback and better defined corner cases
This commit is contained in:
parent
31e9519ef5
commit
fd0c51e48a
5 changed files with 77 additions and 22 deletions
|
|
@ -107,12 +107,16 @@ class AuthHandler(object):
|
|||
"""Get Responses for challenges from authenticators."""
|
||||
cont_resp = []
|
||||
dv_resp = []
|
||||
logger.info("Attempting to set up challenges.")
|
||||
with error_handler.ErrorHandler(self._cleanup_challenges):
|
||||
if self.cont_c:
|
||||
cont_resp = self.cont_auth.perform(self.cont_c)
|
||||
if self.dv_c:
|
||||
dv_resp = self.dv_auth.perform(self.dv_c)
|
||||
try:
|
||||
if self.cont_c:
|
||||
cont_resp = self.cont_auth.perform(self.cont_c)
|
||||
if self.dv_c:
|
||||
dv_resp = self.dv_auth.perform(self.dv_c)
|
||||
except errors.AuthorizationError:
|
||||
logger.critical("Failure in setting up challenges.")
|
||||
logger.info("Attempting to clean up outstanding challenges...")
|
||||
raise
|
||||
|
||||
assert len(cont_resp) == len(self.cont_c)
|
||||
assert len(dv_resp) == len(self.dv_c)
|
||||
|
|
|
|||
|
|
@ -415,8 +415,11 @@ class Client(object):
|
|||
"""
|
||||
with error_handler.ErrorHandler(self.installer.recovery_routine):
|
||||
for dom in domains:
|
||||
logger.info("Attempting to perform redirect for %s", dom)
|
||||
self.installer.enhance(dom, "redirect")
|
||||
try:
|
||||
self.installer.enhance(dom, "redirect")
|
||||
except errors.PluginError:
|
||||
logger.warn("Unable to perform redirect for %s", dom)
|
||||
raise
|
||||
|
||||
self.installer.save("Add Redirects")
|
||||
self.installer.restart()
|
||||
|
|
|
|||
|
|
@ -1,26 +1,58 @@
|
|||
"""Registers and calls cleanup functions in case of an error."""
|
||||
"""Registers functions to be called if an exception or signal occurs."""
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import traceback
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# _SIGNALS stores the signals that will be handled by the ErrorHandler. These
|
||||
# signals were chosen as their default handler terminates the process and could
|
||||
# potentially occur from inside Python. Signals such as SIGILL were not
|
||||
# included as they could be a sign of something devious and we should terminate
|
||||
# immediately.
|
||||
_SIGNALS = ([signal.SIGTERM] if os.name == "nt" else
|
||||
[signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT,
|
||||
signal.SIGXCPU, signal.SIGXFSZ, signal.SIGPWR])
|
||||
|
||||
|
||||
class ErrorHandler(object):
|
||||
"""Registers and calls cleanup functions in case of an error."""
|
||||
"""Registers functions to be called if an exception or signal occurs.
|
||||
|
||||
This class allows you to register functions that will be called when
|
||||
an exception or signal is encountered. The class works best as a
|
||||
context manager. For example:
|
||||
|
||||
with ErrorHandler(cleanup_func):
|
||||
do_something()
|
||||
|
||||
If an exception is raised out of do_something, cleanup_func will be
|
||||
called. The exception is not caught by the ErrorHandler. Similarly,
|
||||
if a signal is encountered, cleanup_func is called followed by the
|
||||
previously registered signal handler.
|
||||
|
||||
Every registered function is attempted to be run to completion
|
||||
exactly once. If a registered function raises an exception, it is
|
||||
logged and the next function is called. If a (different) handled
|
||||
signal occurs while calling a registered function, it is attempted
|
||||
to be called again by the next signal handler.
|
||||
|
||||
"""
|
||||
def __init__(self, func=None):
|
||||
self.funcs = []
|
||||
self.prev_handlers = {}
|
||||
if func:
|
||||
if func is not None:
|
||||
self.register(func)
|
||||
|
||||
def __enter__(self):
|
||||
self.set_signal_handlers()
|
||||
|
||||
def __exit__(self, exec_type, exec_value, traceback):
|
||||
def __exit__(self, exec_type, exec_value, trace):
|
||||
if exec_value is not None:
|
||||
logger.debug("Encountered exception:\n%s", "".join(
|
||||
traceback.format_exception(exec_type, exec_value, trace)))
|
||||
self.call_registered()
|
||||
self.reset_signal_handlers()
|
||||
|
||||
|
|
@ -29,9 +61,15 @@ class ErrorHandler(object):
|
|||
self.funcs.append(func)
|
||||
|
||||
def call_registered(self):
|
||||
"""Calls all functions in the order they were registered."""
|
||||
for func in self.funcs:
|
||||
func()
|
||||
"""Calls all registered functions"""
|
||||
logger.debug("Calling registered functions")
|
||||
while self.funcs:
|
||||
try:
|
||||
self.funcs[-1]()
|
||||
except Exception as error: # pylint: disable=broad-except
|
||||
logger.error("Encountered exception during recovery")
|
||||
logger.exception(error)
|
||||
self.funcs.pop()
|
||||
|
||||
def set_signal_handlers(self):
|
||||
"""Sets signal handlers for signals in _SIGNALS."""
|
||||
|
|
@ -48,12 +86,13 @@ class ErrorHandler(object):
|
|||
signal.signal(signum, self.prev_handlers[signum])
|
||||
self.prev_handlers.clear()
|
||||
|
||||
def _signal_handler(self, signum, _):
|
||||
def _signal_handler(self, signum, unused_frame):
|
||||
"""Calls registered functions and the previous signal handler.
|
||||
|
||||
:param int signum: number of current signal
|
||||
|
||||
"""
|
||||
logger.debug("Singal %s encountered", signum)
|
||||
self.call_registered()
|
||||
signal.signal(signum, self.prev_handlers[signum])
|
||||
os.kill(os.getpid(), signum)
|
||||
|
|
|
|||
|
|
@ -189,7 +189,7 @@ class ClientTest(unittest.TestCase):
|
|||
installer.deploy_cert.assert_called_once_with(
|
||||
"foo.bar", os.path.abspath("cert"),
|
||||
os.path.abspath("key"), os.path.abspath("chain"))
|
||||
self.assertTrue(installer.save.call_count == 1)
|
||||
self.assertEqual(installer.save.call_count, 1)
|
||||
installer.restart.assert_called_once_with()
|
||||
|
||||
@mock.patch("letsencrypt.client.enhancements")
|
||||
|
|
@ -203,7 +203,7 @@ class ClientTest(unittest.TestCase):
|
|||
|
||||
self.client.enhance_config(["foo.bar"])
|
||||
installer.enhance.assert_called_once_with("foo.bar", "redirect")
|
||||
self.assertTrue(installer.save.call_count == 1)
|
||||
self.assertEqual(installer.save.call_count, 1)
|
||||
installer.restart.assert_called_once_with()
|
||||
|
||||
installer.enhance.side_effect = errors.PluginError
|
||||
|
|
|
|||
|
|
@ -4,15 +4,17 @@ import unittest
|
|||
|
||||
import mock
|
||||
|
||||
from letsencrypt import error_handler
|
||||
|
||||
|
||||
class ErrorHandlerTest(unittest.TestCase):
|
||||
"""Tests for letsencrypt.error_handler."""
|
||||
|
||||
def setUp(self):
|
||||
from letsencrypt import error_handler
|
||||
|
||||
self.init_func = mock.MagicMock()
|
||||
self.handler = error_handler.ErrorHandler(self.init_func)
|
||||
# pylint: disable=protected-access
|
||||
self.signals = error_handler._SIGNALS
|
||||
|
||||
def test_context_manager(self):
|
||||
try:
|
||||
|
|
@ -29,18 +31,25 @@ class ErrorHandlerTest(unittest.TestCase):
|
|||
mock_signal.getsignal.return_value = signal.SIG_DFL
|
||||
self.handler.set_signal_handlers()
|
||||
signal_handler = self.handler._signal_handler
|
||||
for signum in error_handler._SIGNALS:
|
||||
for signum in self.signals:
|
||||
mock_signal.signal.assert_any_call(signum, signal_handler)
|
||||
|
||||
signum = error_handler._SIGNALS[0]
|
||||
signum = self.signals[0]
|
||||
signal_handler(signum, None)
|
||||
self.init_func.assert_called_once_with()
|
||||
mock_os.kill.assert_called_once_with(mock_os.getpid(), signum)
|
||||
|
||||
self.handler.reset_signal_handlers()
|
||||
for signum in error_handler._SIGNALS:
|
||||
for signum in self.signals:
|
||||
mock_signal.signal.assert_any_call(signum, signal.SIG_DFL)
|
||||
|
||||
def test_bad_recovery(self):
|
||||
bad_func = mock.MagicMock(side_effect=[ValueError])
|
||||
self.handler.register(bad_func)
|
||||
self.handler.call_registered()
|
||||
self.init_func.assert_called_once_with()
|
||||
bad_func.assert_called_once_with()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main() # pragma: no cover
|
||||
|
|
|
|||
Loading…
Reference in a new issue