diff --git a/certbot/auth_handler.py b/certbot/auth_handler.py index 68389d1f8..9d7c75f57 100644 --- a/certbot/auth_handler.py +++ b/certbot/auth_handler.py @@ -69,14 +69,15 @@ class AuthHandler(object): # While there are still challenges remaining... while self._has_challenges(aauthzrs): - resp = self._solve_challenges(aauthzrs) - logger.info("Waiting for verification...") - if config.debug_challenges: - notify('Challenges loaded. Press continue to submit to CA. ' - 'Pass "-v" for more info about challenges.', pause=True) + with error_handler.ExitHandler(self._cleanup_challenges, aauthzrs): + resp = self._solve_challenges(aauthzrs) + logger.info("Waiting for verification...") + if config.debug_challenges: + notify('Challenges loaded. Press continue to submit to CA. ' + 'Pass "-v" for more info about challenges.', pause=True) - # Send all Responses - this modifies achalls - self._respond(aauthzrs, resp, best_effort) + # Send all Responses - this modifies achalls + self._respond(aauthzrs, resp, best_effort) # Just make sure all decisions are complete. self.verify_authzr_complete(aauthzrs) @@ -118,14 +119,13 @@ class AuthHandler(object): """Get Responses for challenges from authenticators.""" resp = [] all_achalls = self._get_all_achalls(aauthzrs) - with error_handler.ErrorHandler(self._cleanup_challenges, aauthzrs, all_achalls): - try: - if all_achalls: - resp = self.auth.perform(all_achalls) - except errors.AuthorizationError: - logger.critical("Failure in setting up challenges.") - logger.info("Attempting to clean up outstanding challenges...") - raise + try: + if all_achalls: + resp = self.auth.perform(all_achalls) + except errors.AuthorizationError: + logger.critical("Failure in setting up challenges.") + logger.info("Attempting to clean up outstanding challenges...") + raise assert len(resp) == len(all_achalls) @@ -147,13 +147,10 @@ class AuthHandler(object): """ # TODO: chall_update is a dirty hack to get around acme-spec #105 chall_update = dict() - active_achalls = self._send_responses(aauthzrs, resp, chall_update) + self._send_responses(aauthzrs, resp, chall_update) # Check for updated status... - try: - self._poll_challenges(aauthzrs, chall_update, best_effort) - finally: - self._cleanup_challenges(aauthzrs, active_achalls) + self._poll_challenges(aauthzrs, chall_update, best_effort) def _send_responses(self, aauthzrs, resps, chall_update): """Send responses and make sure errors are handled. @@ -294,7 +291,7 @@ class AuthHandler(object): chall_prefs.extend(plugin_pref) return chall_prefs - def _cleanup_challenges(self, aauthzrs, achalls): + def _cleanup_challenges(self, aauthzrs, achalls=None): """Cleanup challenges. :param aauthzrs: authorizations and their selected annotated @@ -305,7 +302,8 @@ class AuthHandler(object): """ logger.info("Cleaning up challenges") - + if achalls is None: + achalls = self._get_all_achalls(aauthzrs) if achalls: self.auth.cleanup(achalls) for achall in achalls: diff --git a/certbot/error_handler.py b/certbot/error_handler.py index 842243f70..e2737711e 100644 --- a/certbot/error_handler.py +++ b/certbot/error_handler.py @@ -24,7 +24,6 @@ if os.name != "nt": if signal.getsignal(signal_code) != signal.SIG_IGN: _SIGNALS.append(signal_code) - class ErrorHandler(object): """Context manager for running code that must be cleaned up on failure. @@ -55,6 +54,7 @@ class ErrorHandler(object): """ def __init__(self, func=None, *args, **kwargs): + self.call_on_regular_exit = False self.body_executed = False self.funcs = [] self.prev_handlers = {} @@ -70,8 +70,11 @@ class ErrorHandler(object): self.body_executed = True retval = False # SystemExit is ignored to properly handle forks that don't exec - if exec_type in (None, SystemExit): + if exec_type is SystemExit: return retval + elif exec_type is None: + if not self.call_on_regular_exit: + return retval elif exec_type is errors.SignalExit: logger.debug("Encountered signals: %s", self.received_signals) retval = True @@ -136,3 +139,15 @@ class ErrorHandler(object): for signum in self.received_signals: logger.debug("Calling signal %s", signum) os.kill(os.getpid(), signum) + +class ExitHandler(ErrorHandler): + """Context manager for running code that must be cleaned up. + + Subclass of ErrorHandler, with the same usage and parameters. + In addition to cleaning up on all signals, also cleans up on + regular exit. + """ + def __init__(self, func=None, *args, **kwargs): + ErrorHandler.__init__(self, func, *args, **kwargs) + self.call_on_regular_exit = True + diff --git a/certbot/tests/auth_handler_test.py b/certbot/tests/auth_handler_test.py index a4ac9eb73..9a8a13498 100644 --- a/certbot/tests/auth_handler_test.py +++ b/certbot/tests/auth_handler_test.py @@ -289,6 +289,32 @@ class HandleAuthorizationsTest(unittest.TestCase): self.assertEqual( self.mock_auth.cleanup.call_args[0][0][0].typ, "tls-sni-01") + @mock.patch("certbot.auth_handler.AuthHandler._respond") + def test_respond_error(self, mock_respond): + authzrs = [gen_dom_authzr(domain="0", challs=acme_util.CHALLENGES)] + mock_order = mock.MagicMock(authorizations=authzrs) + mock_respond.side_effect = errors.AuthorizationError + + self.assertRaises( + errors.AuthorizationError, self.handler.handle_authorizations, mock_order) + self.assertEqual(self.mock_auth.cleanup.call_count, 1) + self.assertEqual( + self.mock_auth.cleanup.call_args[0][0][0].typ, "tls-sni-01") + + @mock.patch("certbot.auth_handler.AuthHandler._poll_challenges") + @mock.patch("certbot.auth_handler.AuthHandler.verify_authzr_complete") + def test_incomplete_authzr_error(self, mock_verify, mock_poll): + authzrs = [gen_dom_authzr(domain="0", challs=acme_util.CHALLENGES)] + mock_order = mock.MagicMock(authorizations=authzrs) + mock_verify.side_effect = errors.AuthorizationError + mock_poll.side_effect = self._validate_all + + self.assertRaises( + errors.AuthorizationError, self.handler.handle_authorizations, mock_order) + self.assertEqual(self.mock_auth.cleanup.call_count, 1) + self.assertEqual( + self.mock_auth.cleanup.call_args[0][0][0].typ, "tls-sni-01") + def _validate_all(self, aauthzrs, unused_1, unused_2): for i, aauthzr in enumerate(aauthzrs): azr = aauthzr.authzr diff --git a/certbot/tests/error_handler_test.py b/certbot/tests/error_handler_test.py index 60dcf5e99..d4c48c242 100644 --- a/certbot/tests/error_handler_test.py +++ b/certbot/tests/error_handler_test.py @@ -36,7 +36,7 @@ def send_signal(signum): class ErrorHandlerTest(unittest.TestCase): - """Tests for certbot.error_handler.""" + """Tests for certbot.error_handler.ErrorHandler.""" def setUp(self): from certbot import error_handler @@ -47,6 +47,7 @@ class ErrorHandlerTest(unittest.TestCase): self.handler = error_handler.ErrorHandler(self.init_func, *self.init_args, **self.init_kwargs) + # pylint: disable=protected-access self.signals = error_handler._SIGNALS @@ -113,6 +114,33 @@ class ErrorHandlerTest(unittest.TestCase): pass self.assertFalse(self.init_func.called) + def test_regular_exit(self): + func = mock.MagicMock() + self.handler.register(func) + with self.handler: + pass + self.init_func.assert_not_called() + func.assert_not_called() + + +class ExitHandlerTest(ErrorHandlerTest): + """Tests for certbot.error_handler.ExitHandler.""" + + def setUp(self): + from certbot import error_handler + super(ExitHandlerTest, self).setUp() + self.handler = error_handler.ExitHandler(self.init_func, + *self.init_args, + **self.init_kwargs) + + def test_regular_exit(self): + func = mock.MagicMock() + self.handler.register(func) + with self.handler: + pass + self.init_func.assert_called_once_with(*self.init_args, + **self.init_kwargs) + func.assert_called_once_with() if __name__ == "__main__": unittest.main() # pragma: no cover