diff --git a/letsencrypt/auth_handler.py b/letsencrypt/auth_handler.py index 6498a5c19..a285825dc 100644 --- a/letsencrypt/auth_handler.py +++ b/letsencrypt/auth_handler.py @@ -11,6 +11,7 @@ from acme import messages from letsencrypt import achallenges from letsencrypt import constants from letsencrypt import errors +from letsencrypt import error_handler from letsencrypt import interfaces @@ -106,17 +107,12 @@ class AuthHandler(object): """Get Responses for challenges from authenticators.""" cont_resp = [] dv_resp = [] - try: + logger.info("Attempting to set up challenges.") + with error_handler.ErrorHandler(self._cleanup_challenges): if self.cont_c: cont_resp = self.cont_auth.perform(self.cont_c) if self.dv_c: dv_resp = self.dv_auth.perform(self.dv_c) - # This will catch both specific types of errors. - except errors.AuthorizationError: - logger.critical("Failure in setting up challenges.") - logger.info("Attempting to clean up outstanding challenges...") - self._cleanup_challenges() - raise assert len(cont_resp) == len(self.cont_c) assert len(dv_resp) == len(self.dv_c) diff --git a/letsencrypt/client.py b/letsencrypt/client.py index 60eaea5a1..3f1f4900b 100644 --- a/letsencrypt/client.py +++ b/letsencrypt/client.py @@ -18,6 +18,7 @@ from letsencrypt import constants from letsencrypt import continuity_auth from letsencrypt import crypto_util from letsencrypt import errors +from letsencrypt import error_handler from letsencrypt import interfaces from letsencrypt import le_util from letsencrypt import reverter @@ -364,16 +365,17 @@ class Client(object): chain_path = None if chain_path is None else os.path.abspath(chain_path) - for dom in domains: - # TODO: Provide a fullchain reference for installers like - # nginx that want it - self.installer.deploy_cert( - dom, os.path.abspath(cert_path), - os.path.abspath(privkey_path), chain_path) + with error_handler.ErrorHandler(self.installer.recovery_routine): + for dom in domains: + # TODO: Provide a fullchain reference for installers like + # nginx that want it + self.installer.deploy_cert( + dom, os.path.abspath(cert_path), + os.path.abspath(privkey_path), chain_path) - self.installer.save("Deployed Let's Encrypt Certificate") - # sites may have been enabled / final cleanup - self.installer.restart() + self.installer.save("Deployed Let's Encrypt Certificate") + # sites may have been enabled / final cleanup + self.installer.restart() def enhance_config(self, domains, redirect=None): """Enhance the configuration. @@ -399,6 +401,8 @@ class Client(object): if redirect is None: redirect = enhancements.ask("redirect") + # When support for more enhancements are added, the call to the + # plugin's `enhance` function should be wrapped by an ErrorHandler if redirect: self.redirect_to_ssl(domains) @@ -409,14 +413,13 @@ class Client(object): :type vhost: :class:`letsencrypt.interfaces.IInstaller` """ - for dom in domains: - try: + with error_handler.ErrorHandler(self.installer.recovery_routine): + for dom in domains: + logger.info("Attempting to perform redirect for %s", dom) self.installer.enhance(dom, "redirect") - except errors.PluginError: - logger.warn("Unable to perform redirect for %s", dom) - self.installer.save("Add Redirects") - self.installer.restart() + self.installer.save("Add Redirects") + self.installer.restart() def validate_key_csr(privkey, csr=None): diff --git a/letsencrypt/error_handler.py b/letsencrypt/error_handler.py index b82f49b5a..3fc948b54 100644 --- a/letsencrypt/error_handler.py +++ b/letsencrypt/error_handler.py @@ -11,8 +11,10 @@ _SIGNALS = ([signal.SIGTERM] if os.name == "nt" else class ErrorHandler(object): """Registers and calls cleanup functions in case of an error.""" def __init__(self, func=None): - self.funcs = [func] if func else [] + self.funcs = [] self.prev_handlers = {} + if func: + self.register(func) def __enter__(self): self.set_signal_handlers() diff --git a/letsencrypt/interfaces.py b/letsencrypt/interfaces.py index af145ab0a..a0d2eb97f 100644 --- a/letsencrypt/interfaces.py +++ b/letsencrypt/interfaces.py @@ -321,7 +321,7 @@ class IInstaller(IPlugin): """ - def recovery_routine(self): + def recovery_routine(): """Revert configuration to most recent finalized checkpoint. Remove all changes (temporary and permanent) that have not been diff --git a/letsencrypt/tests/client_test.py b/letsencrypt/tests/client_test.py index 93fdf2cd3..0131d3c93 100644 --- a/letsencrypt/tests/client_test.py +++ b/letsencrypt/tests/client_test.py @@ -178,6 +178,39 @@ class ClientTest(unittest.TestCase): shutil.rmtree(tmp_path) + def test_deploy_certificate(self): + self.assertRaises(errors.Error, self.client.deploy_certificate, + ["foo.bar"], "key", "cert", "chain") + + installer = mock.MagicMock() + self.client.installer = installer + + self.client.deploy_certificate(["foo.bar"], "key", "cert", "chain") + installer.deploy_cert.assert_called_once_with( + "foo.bar", os.path.abspath("cert"), + os.path.abspath("key"), os.path.abspath("chain")) + self.assertTrue(installer.save.call_count == 1) + installer.restart.assert_called_once_with() + + @mock.patch("letsencrypt.client.enhancements") + def test_enhance_config(self, mock_enhancements): + self.assertRaises(errors.Error, + self.client.enhance_config, ["foo.bar"]) + + mock_enhancements.ask.return_value = True + installer = mock.MagicMock() + self.client.installer = installer + + self.client.enhance_config(["foo.bar"]) + installer.enhance.assert_called_once_with("foo.bar", "redirect") + self.assertTrue(installer.save.call_count == 1) + installer.restart.assert_called_once_with() + + installer.enhance.side_effect = errors.PluginError + self.assertRaises(errors.PluginError, + self.client.enhance_config, ["foo.bar"], True) + installer.recovery_routine.assert_called_once_with() + class RollbackTest(unittest.TestCase): """Tests for letsencrypt.client.rollback.""" diff --git a/letsencrypt/tests/error_handler_test.py b/letsencrypt/tests/error_handler_test.py index 6c6d02ec3..6927b32a0 100644 --- a/letsencrypt/tests/error_handler_test.py +++ b/letsencrypt/tests/error_handler_test.py @@ -1,25 +1,46 @@ """Tests for letsencrypt.error_handler.""" +import signal import unittest import mock +from letsencrypt import error_handler + class ErrorHandlerTest(unittest.TestCase): """Tests for letsencrypt.error_handler.""" def setUp(self): - from letsencrypt import error_handler self.init_func = mock.MagicMock() - self.error_handler = error_handler.ErrorHandler(self.init_func) + self.handler = error_handler.ErrorHandler(self.init_func) def test_context_manager(self): try: - with self.error_handler: + with self.handler: raise ValueError except ValueError: pass self.init_func.assert_called_once_with() + @mock.patch('letsencrypt.error_handler.os') + @mock.patch('letsencrypt.error_handler.signal') + def test_signal_handler(self, mock_signal, mock_os): + # pylint: disable=protected-access + mock_signal.getsignal.return_value = signal.SIG_DFL + self.handler.set_signal_handlers() + signal_handler = self.handler._signal_handler + for signum in error_handler._SIGNALS: + mock_signal.signal.assert_any_call(signum, signal_handler) + + signum = error_handler._SIGNALS[0] + signal_handler(signum, None) + self.init_func.assert_called_once_with() + mock_os.kill.assert_called_once_with(mock_os.getpid(), signum) + + self.handler.reset_signal_handlers() + for signum in error_handler._SIGNALS: + mock_signal.signal.assert_any_call(signum, signal.SIG_DFL) + if __name__ == "__main__": unittest.main() # pragma: no cover