Incorporated Kuba's feedback and better defined corner cases

This commit is contained in:
Brad Warren 2015-09-24 16:23:40 -07:00
parent 31e9519ef5
commit fd0c51e48a
5 changed files with 77 additions and 22 deletions

View file

@ -107,12 +107,16 @@ class AuthHandler(object):
"""Get Responses for challenges from authenticators."""
cont_resp = []
dv_resp = []
logger.info("Attempting to set up challenges.")
with error_handler.ErrorHandler(self._cleanup_challenges):
if self.cont_c:
cont_resp = self.cont_auth.perform(self.cont_c)
if self.dv_c:
dv_resp = self.dv_auth.perform(self.dv_c)
try:
if self.cont_c:
cont_resp = self.cont_auth.perform(self.cont_c)
if self.dv_c:
dv_resp = self.dv_auth.perform(self.dv_c)
except errors.AuthorizationError:
logger.critical("Failure in setting up challenges.")
logger.info("Attempting to clean up outstanding challenges...")
raise
assert len(cont_resp) == len(self.cont_c)
assert len(dv_resp) == len(self.dv_c)

View file

@ -415,8 +415,11 @@ class Client(object):
"""
with error_handler.ErrorHandler(self.installer.recovery_routine):
for dom in domains:
logger.info("Attempting to perform redirect for %s", dom)
self.installer.enhance(dom, "redirect")
try:
self.installer.enhance(dom, "redirect")
except errors.PluginError:
logger.warn("Unable to perform redirect for %s", dom)
raise
self.installer.save("Add Redirects")
self.installer.restart()

View file

@ -1,26 +1,58 @@
"""Registers and calls cleanup functions in case of an error."""
"""Registers functions to be called if an exception or signal occurs."""
import logging
import os
import signal
import traceback
logger = logging.getLogger(__name__)
# _SIGNALS stores the signals that will be handled by the ErrorHandler. These
# signals were chosen as their default handler terminates the process and could
# potentially occur from inside Python. Signals such as SIGILL were not
# included as they could be a sign of something devious and we should terminate
# immediately.
_SIGNALS = ([signal.SIGTERM] if os.name == "nt" else
[signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT,
signal.SIGXCPU, signal.SIGXFSZ, signal.SIGPWR])
class ErrorHandler(object):
"""Registers and calls cleanup functions in case of an error."""
"""Registers functions to be called if an exception or signal occurs.
This class allows you to register functions that will be called when
an exception or signal is encountered. The class works best as a
context manager. For example:
with ErrorHandler(cleanup_func):
do_something()
If an exception is raised out of do_something, cleanup_func will be
called. The exception is not caught by the ErrorHandler. Similarly,
if a signal is encountered, cleanup_func is called followed by the
previously registered signal handler.
Every registered function is attempted to be run to completion
exactly once. If a registered function raises an exception, it is
logged and the next function is called. If a (different) handled
signal occurs while calling a registered function, it is attempted
to be called again by the next signal handler.
"""
def __init__(self, func=None):
self.funcs = []
self.prev_handlers = {}
if func:
if func is not None:
self.register(func)
def __enter__(self):
self.set_signal_handlers()
def __exit__(self, exec_type, exec_value, traceback):
def __exit__(self, exec_type, exec_value, trace):
if exec_value is not None:
logger.debug("Encountered exception:\n%s", "".join(
traceback.format_exception(exec_type, exec_value, trace)))
self.call_registered()
self.reset_signal_handlers()
@ -29,9 +61,15 @@ class ErrorHandler(object):
self.funcs.append(func)
def call_registered(self):
"""Calls all functions in the order they were registered."""
for func in self.funcs:
func()
"""Calls all registered functions"""
logger.debug("Calling registered functions")
while self.funcs:
try:
self.funcs[-1]()
except Exception as error: # pylint: disable=broad-except
logger.error("Encountered exception during recovery")
logger.exception(error)
self.funcs.pop()
def set_signal_handlers(self):
"""Sets signal handlers for signals in _SIGNALS."""
@ -48,12 +86,13 @@ class ErrorHandler(object):
signal.signal(signum, self.prev_handlers[signum])
self.prev_handlers.clear()
def _signal_handler(self, signum, _):
def _signal_handler(self, signum, unused_frame):
"""Calls registered functions and the previous signal handler.
:param int signum: number of current signal
"""
logger.debug("Singal %s encountered", signum)
self.call_registered()
signal.signal(signum, self.prev_handlers[signum])
os.kill(os.getpid(), signum)

View file

@ -189,7 +189,7 @@ class ClientTest(unittest.TestCase):
installer.deploy_cert.assert_called_once_with(
"foo.bar", os.path.abspath("cert"),
os.path.abspath("key"), os.path.abspath("chain"))
self.assertTrue(installer.save.call_count == 1)
self.assertEqual(installer.save.call_count, 1)
installer.restart.assert_called_once_with()
@mock.patch("letsencrypt.client.enhancements")
@ -203,7 +203,7 @@ class ClientTest(unittest.TestCase):
self.client.enhance_config(["foo.bar"])
installer.enhance.assert_called_once_with("foo.bar", "redirect")
self.assertTrue(installer.save.call_count == 1)
self.assertEqual(installer.save.call_count, 1)
installer.restart.assert_called_once_with()
installer.enhance.side_effect = errors.PluginError

View file

@ -4,15 +4,17 @@ import unittest
import mock
from letsencrypt import error_handler
class ErrorHandlerTest(unittest.TestCase):
"""Tests for letsencrypt.error_handler."""
def setUp(self):
from letsencrypt import error_handler
self.init_func = mock.MagicMock()
self.handler = error_handler.ErrorHandler(self.init_func)
# pylint: disable=protected-access
self.signals = error_handler._SIGNALS
def test_context_manager(self):
try:
@ -29,18 +31,25 @@ class ErrorHandlerTest(unittest.TestCase):
mock_signal.getsignal.return_value = signal.SIG_DFL
self.handler.set_signal_handlers()
signal_handler = self.handler._signal_handler
for signum in error_handler._SIGNALS:
for signum in self.signals:
mock_signal.signal.assert_any_call(signum, signal_handler)
signum = error_handler._SIGNALS[0]
signum = self.signals[0]
signal_handler(signum, None)
self.init_func.assert_called_once_with()
mock_os.kill.assert_called_once_with(mock_os.getpid(), signum)
self.handler.reset_signal_handlers()
for signum in error_handler._SIGNALS:
for signum in self.signals:
mock_signal.signal.assert_any_call(signum, signal.SIG_DFL)
def test_bad_recovery(self):
bad_func = mock.MagicMock(side_effect=[ValueError])
self.handler.register(bad_func)
self.handler.call_registered()
self.init_func.assert_called_once_with()
bad_func.assert_called_once_with()
if __name__ == "__main__":
unittest.main() # pragma: no cover