From fd0c51e48afef3fb618d5027d4420a921c00f9a3 Mon Sep 17 00:00:00 2001 From: Brad Warren Date: Thu, 24 Sep 2015 16:23:40 -0700 Subject: [PATCH] Incorporated Kuba's feedback and better defined corner cases --- letsencrypt/auth_handler.py | 14 ++++--- letsencrypt/client.py | 7 +++- letsencrypt/error_handler.py | 55 +++++++++++++++++++++---- letsencrypt/tests/client_test.py | 4 +- letsencrypt/tests/error_handler_test.py | 19 ++++++--- 5 files changed, 77 insertions(+), 22 deletions(-) diff --git a/letsencrypt/auth_handler.py b/letsencrypt/auth_handler.py index a285825dc..68aed510a 100644 --- a/letsencrypt/auth_handler.py +++ b/letsencrypt/auth_handler.py @@ -107,12 +107,16 @@ class AuthHandler(object): """Get Responses for challenges from authenticators.""" cont_resp = [] dv_resp = [] - logger.info("Attempting to set up challenges.") with error_handler.ErrorHandler(self._cleanup_challenges): - if self.cont_c: - cont_resp = self.cont_auth.perform(self.cont_c) - if self.dv_c: - dv_resp = self.dv_auth.perform(self.dv_c) + try: + if self.cont_c: + cont_resp = self.cont_auth.perform(self.cont_c) + if self.dv_c: + dv_resp = self.dv_auth.perform(self.dv_c) + except errors.AuthorizationError: + logger.critical("Failure in setting up challenges.") + logger.info("Attempting to clean up outstanding challenges...") + raise assert len(cont_resp) == len(self.cont_c) assert len(dv_resp) == len(self.dv_c) diff --git a/letsencrypt/client.py b/letsencrypt/client.py index 3f1f4900b..56d9b1fda 100644 --- a/letsencrypt/client.py +++ b/letsencrypt/client.py @@ -415,8 +415,11 @@ class Client(object): """ with error_handler.ErrorHandler(self.installer.recovery_routine): for dom in domains: - logger.info("Attempting to perform redirect for %s", dom) - self.installer.enhance(dom, "redirect") + try: + self.installer.enhance(dom, "redirect") + except errors.PluginError: + logger.warn("Unable to perform redirect for %s", dom) + raise self.installer.save("Add Redirects") self.installer.restart() diff --git a/letsencrypt/error_handler.py b/letsencrypt/error_handler.py index 3fc948b54..fedb66c0e 100644 --- a/letsencrypt/error_handler.py +++ b/letsencrypt/error_handler.py @@ -1,26 +1,58 @@ -"""Registers and calls cleanup functions in case of an error.""" +"""Registers functions to be called if an exception or signal occurs.""" +import logging import os import signal +import traceback +logger = logging.getLogger(__name__) + + +# _SIGNALS stores the signals that will be handled by the ErrorHandler. These +# signals were chosen as their default handler terminates the process and could +# potentially occur from inside Python. Signals such as SIGILL were not +# included as they could be a sign of something devious and we should terminate +# immediately. _SIGNALS = ([signal.SIGTERM] if os.name == "nt" else [signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT, signal.SIGXCPU, signal.SIGXFSZ, signal.SIGPWR]) class ErrorHandler(object): - """Registers and calls cleanup functions in case of an error.""" + """Registers functions to be called if an exception or signal occurs. + + This class allows you to register functions that will be called when + an exception or signal is encountered. The class works best as a + context manager. For example: + + with ErrorHandler(cleanup_func): + do_something() + + If an exception is raised out of do_something, cleanup_func will be + called. The exception is not caught by the ErrorHandler. Similarly, + if a signal is encountered, cleanup_func is called followed by the + previously registered signal handler. + + Every registered function is attempted to be run to completion + exactly once. If a registered function raises an exception, it is + logged and the next function is called. If a (different) handled + signal occurs while calling a registered function, it is attempted + to be called again by the next signal handler. + + """ def __init__(self, func=None): self.funcs = [] self.prev_handlers = {} - if func: + if func is not None: self.register(func) def __enter__(self): self.set_signal_handlers() - def __exit__(self, exec_type, exec_value, traceback): + def __exit__(self, exec_type, exec_value, trace): if exec_value is not None: + logger.debug("Encountered exception:\n%s", "".join( + traceback.format_exception(exec_type, exec_value, trace))) self.call_registered() self.reset_signal_handlers() @@ -29,9 +61,15 @@ class ErrorHandler(object): self.funcs.append(func) def call_registered(self): - """Calls all functions in the order they were registered.""" - for func in self.funcs: - func() + """Calls all registered functions""" + logger.debug("Calling registered functions") + while self.funcs: + try: + self.funcs[-1]() + except Exception as error: # pylint: disable=broad-except + logger.error("Encountered exception during recovery") + logger.exception(error) + self.funcs.pop() def set_signal_handlers(self): """Sets signal handlers for signals in _SIGNALS.""" @@ -48,12 +86,13 @@ class ErrorHandler(object): signal.signal(signum, self.prev_handlers[signum]) self.prev_handlers.clear() - def _signal_handler(self, signum, _): + def _signal_handler(self, signum, unused_frame): """Calls registered functions and the previous signal handler. :param int signum: number of current signal """ + logger.debug("Singal %s encountered", signum) self.call_registered() signal.signal(signum, self.prev_handlers[signum]) os.kill(os.getpid(), signum) diff --git a/letsencrypt/tests/client_test.py b/letsencrypt/tests/client_test.py index 0131d3c93..83cd54226 100644 --- a/letsencrypt/tests/client_test.py +++ b/letsencrypt/tests/client_test.py @@ -189,7 +189,7 @@ class ClientTest(unittest.TestCase): installer.deploy_cert.assert_called_once_with( "foo.bar", os.path.abspath("cert"), os.path.abspath("key"), os.path.abspath("chain")) - self.assertTrue(installer.save.call_count == 1) + self.assertEqual(installer.save.call_count, 1) installer.restart.assert_called_once_with() @mock.patch("letsencrypt.client.enhancements") @@ -203,7 +203,7 @@ class ClientTest(unittest.TestCase): self.client.enhance_config(["foo.bar"]) installer.enhance.assert_called_once_with("foo.bar", "redirect") - self.assertTrue(installer.save.call_count == 1) + self.assertEqual(installer.save.call_count, 1) installer.restart.assert_called_once_with() installer.enhance.side_effect = errors.PluginError diff --git a/letsencrypt/tests/error_handler_test.py b/letsencrypt/tests/error_handler_test.py index 6927b32a0..66acac930 100644 --- a/letsencrypt/tests/error_handler_test.py +++ b/letsencrypt/tests/error_handler_test.py @@ -4,15 +4,17 @@ import unittest import mock -from letsencrypt import error_handler - class ErrorHandlerTest(unittest.TestCase): """Tests for letsencrypt.error_handler.""" def setUp(self): + from letsencrypt import error_handler + self.init_func = mock.MagicMock() self.handler = error_handler.ErrorHandler(self.init_func) + # pylint: disable=protected-access + self.signals = error_handler._SIGNALS def test_context_manager(self): try: @@ -29,18 +31,25 @@ class ErrorHandlerTest(unittest.TestCase): mock_signal.getsignal.return_value = signal.SIG_DFL self.handler.set_signal_handlers() signal_handler = self.handler._signal_handler - for signum in error_handler._SIGNALS: + for signum in self.signals: mock_signal.signal.assert_any_call(signum, signal_handler) - signum = error_handler._SIGNALS[0] + signum = self.signals[0] signal_handler(signum, None) self.init_func.assert_called_once_with() mock_os.kill.assert_called_once_with(mock_os.getpid(), signum) self.handler.reset_signal_handlers() - for signum in error_handler._SIGNALS: + for signum in self.signals: mock_signal.signal.assert_any_call(signum, signal.SIG_DFL) + def test_bad_recovery(self): + bad_func = mock.MagicMock(side_effect=[ValueError]) + self.handler.register(bad_func) + self.handler.call_registered() + self.init_func.assert_called_once_with() + bad_func.assert_called_once_with() + if __name__ == "__main__": unittest.main() # pragma: no cover