fix(auth_handler): cleanup is always called (#5779)

* fix(auth_handler): cleanup is always called

* test(auth_handler): tests for various error cases
This commit is contained in:
sydneyli 2018-03-26 17:09:02 -07:00 committed by Brad Warren
parent 804fd4b78a
commit af2cce4ca8
4 changed files with 92 additions and 25 deletions

View file

@ -69,14 +69,15 @@ class AuthHandler(object):
# While there are still challenges remaining...
while self._has_challenges(aauthzrs):
resp = self._solve_challenges(aauthzrs)
logger.info("Waiting for verification...")
if config.debug_challenges:
notify('Challenges loaded. Press continue to submit to CA. '
'Pass "-v" for more info about challenges.', pause=True)
with error_handler.ExitHandler(self._cleanup_challenges, aauthzrs):
resp = self._solve_challenges(aauthzrs)
logger.info("Waiting for verification...")
if config.debug_challenges:
notify('Challenges loaded. Press continue to submit to CA. '
'Pass "-v" for more info about challenges.', pause=True)
# Send all Responses - this modifies achalls
self._respond(aauthzrs, resp, best_effort)
# Send all Responses - this modifies achalls
self._respond(aauthzrs, resp, best_effort)
# Just make sure all decisions are complete.
self.verify_authzr_complete(aauthzrs)
@ -118,14 +119,13 @@ class AuthHandler(object):
"""Get Responses for challenges from authenticators."""
resp = []
all_achalls = self._get_all_achalls(aauthzrs)
with error_handler.ErrorHandler(self._cleanup_challenges, aauthzrs, all_achalls):
try:
if all_achalls:
resp = self.auth.perform(all_achalls)
except errors.AuthorizationError:
logger.critical("Failure in setting up challenges.")
logger.info("Attempting to clean up outstanding challenges...")
raise
try:
if all_achalls:
resp = self.auth.perform(all_achalls)
except errors.AuthorizationError:
logger.critical("Failure in setting up challenges.")
logger.info("Attempting to clean up outstanding challenges...")
raise
assert len(resp) == len(all_achalls)
@ -147,13 +147,10 @@ class AuthHandler(object):
"""
# TODO: chall_update is a dirty hack to get around acme-spec #105
chall_update = dict()
active_achalls = self._send_responses(aauthzrs, resp, chall_update)
self._send_responses(aauthzrs, resp, chall_update)
# Check for updated status...
try:
self._poll_challenges(aauthzrs, chall_update, best_effort)
finally:
self._cleanup_challenges(aauthzrs, active_achalls)
self._poll_challenges(aauthzrs, chall_update, best_effort)
def _send_responses(self, aauthzrs, resps, chall_update):
"""Send responses and make sure errors are handled.
@ -294,7 +291,7 @@ class AuthHandler(object):
chall_prefs.extend(plugin_pref)
return chall_prefs
def _cleanup_challenges(self, aauthzrs, achalls):
def _cleanup_challenges(self, aauthzrs, achalls=None):
"""Cleanup challenges.
:param aauthzrs: authorizations and their selected annotated
@ -305,7 +302,8 @@ class AuthHandler(object):
"""
logger.info("Cleaning up challenges")
if achalls is None:
achalls = self._get_all_achalls(aauthzrs)
if achalls:
self.auth.cleanup(achalls)
for achall in achalls:

View file

@ -24,7 +24,6 @@ if os.name != "nt":
if signal.getsignal(signal_code) != signal.SIG_IGN:
_SIGNALS.append(signal_code)
class ErrorHandler(object):
"""Context manager for running code that must be cleaned up on failure.
@ -55,6 +54,7 @@ class ErrorHandler(object):
"""
def __init__(self, func=None, *args, **kwargs):
self.call_on_regular_exit = False
self.body_executed = False
self.funcs = []
self.prev_handlers = {}
@ -70,8 +70,11 @@ class ErrorHandler(object):
self.body_executed = True
retval = False
# SystemExit is ignored to properly handle forks that don't exec
if exec_type in (None, SystemExit):
if exec_type is SystemExit:
return retval
elif exec_type is None:
if not self.call_on_regular_exit:
return retval
elif exec_type is errors.SignalExit:
logger.debug("Encountered signals: %s", self.received_signals)
retval = True
@ -136,3 +139,15 @@ class ErrorHandler(object):
for signum in self.received_signals:
logger.debug("Calling signal %s", signum)
os.kill(os.getpid(), signum)
class ExitHandler(ErrorHandler):
"""Context manager for running code that must be cleaned up.
Subclass of ErrorHandler, with the same usage and parameters.
In addition to cleaning up on all signals, also cleans up on
regular exit.
"""
def __init__(self, func=None, *args, **kwargs):
ErrorHandler.__init__(self, func, *args, **kwargs)
self.call_on_regular_exit = True

View file

@ -289,6 +289,32 @@ class HandleAuthorizationsTest(unittest.TestCase):
self.assertEqual(
self.mock_auth.cleanup.call_args[0][0][0].typ, "tls-sni-01")
@mock.patch("certbot.auth_handler.AuthHandler._respond")
def test_respond_error(self, mock_respond):
authzrs = [gen_dom_authzr(domain="0", challs=acme_util.CHALLENGES)]
mock_order = mock.MagicMock(authorizations=authzrs)
mock_respond.side_effect = errors.AuthorizationError
self.assertRaises(
errors.AuthorizationError, self.handler.handle_authorizations, mock_order)
self.assertEqual(self.mock_auth.cleanup.call_count, 1)
self.assertEqual(
self.mock_auth.cleanup.call_args[0][0][0].typ, "tls-sni-01")
@mock.patch("certbot.auth_handler.AuthHandler._poll_challenges")
@mock.patch("certbot.auth_handler.AuthHandler.verify_authzr_complete")
def test_incomplete_authzr_error(self, mock_verify, mock_poll):
authzrs = [gen_dom_authzr(domain="0", challs=acme_util.CHALLENGES)]
mock_order = mock.MagicMock(authorizations=authzrs)
mock_verify.side_effect = errors.AuthorizationError
mock_poll.side_effect = self._validate_all
self.assertRaises(
errors.AuthorizationError, self.handler.handle_authorizations, mock_order)
self.assertEqual(self.mock_auth.cleanup.call_count, 1)
self.assertEqual(
self.mock_auth.cleanup.call_args[0][0][0].typ, "tls-sni-01")
def _validate_all(self, aauthzrs, unused_1, unused_2):
for i, aauthzr in enumerate(aauthzrs):
azr = aauthzr.authzr

View file

@ -36,7 +36,7 @@ def send_signal(signum):
class ErrorHandlerTest(unittest.TestCase):
"""Tests for certbot.error_handler."""
"""Tests for certbot.error_handler.ErrorHandler."""
def setUp(self):
from certbot import error_handler
@ -47,6 +47,7 @@ class ErrorHandlerTest(unittest.TestCase):
self.handler = error_handler.ErrorHandler(self.init_func,
*self.init_args,
**self.init_kwargs)
# pylint: disable=protected-access
self.signals = error_handler._SIGNALS
@ -113,6 +114,33 @@ class ErrorHandlerTest(unittest.TestCase):
pass
self.assertFalse(self.init_func.called)
def test_regular_exit(self):
func = mock.MagicMock()
self.handler.register(func)
with self.handler:
pass
self.init_func.assert_not_called()
func.assert_not_called()
class ExitHandlerTest(ErrorHandlerTest):
"""Tests for certbot.error_handler.ExitHandler."""
def setUp(self):
from certbot import error_handler
super(ExitHandlerTest, self).setUp()
self.handler = error_handler.ExitHandler(self.init_func,
*self.init_args,
**self.init_kwargs)
def test_regular_exit(self):
func = mock.MagicMock()
self.handler.register(func)
with self.handler:
pass
self.init_func.assert_called_once_with(*self.init_args,
**self.init_kwargs)
func.assert_called_once_with()
if __name__ == "__main__":
unittest.main() # pragma: no cover