diff --git a/certbot/tests/error_handler_test.py b/certbot/tests/error_handler_test.py index 5434b36be..c2551195a 100644 --- a/certbot/tests/error_handler_test.py +++ b/certbot/tests/error_handler_test.py @@ -1,11 +1,33 @@ """Tests for certbot.error_handler.""" +import os import signal import sys import unittest +from contextlib import contextmanager import mock +@contextmanager +def signal_receiver(signums): + """Context manager to catch signals""" + def receiver(signum, unused_frame): + signals.append(signum) + signals = [] + prev_handlers = {} + for signum in signums: + prev_handlers[signum] = signal.getsignal(signum) + signal.signal(signum, receiver) + yield signals + for signum in signums: + signal.signal(signum, prev_handlers[signum]) + + +def send_signal(signum): + """Send the given signal""" + os.kill(os.getpid(), signum) + + class ErrorHandlerTest(unittest.TestCase): """Tests for certbot.error_handler.""" @@ -30,25 +52,20 @@ class ErrorHandlerTest(unittest.TestCase): self.init_func.assert_called_once_with(*self.init_args, **self.init_kwargs) - @mock.patch('certbot.error_handler.os') - @mock.patch('certbot.error_handler.signal') - def test_signal_handler(self, mock_signal, mock_os): - # pylint: disable=protected-access - mock_signal.getsignal.return_value = signal.SIG_DFL - self.handler.set_signal_handlers() - signal_handler = self.handler._signal_handler - for signum in self.signals: - mock_signal.signal.assert_any_call(signum, signal_handler) + def test_context_manager_with_signal(self): + with signal_receiver(self.signals) as signals_received: + with self.handler: + should_be_42 = 42 + send_signal(signal.SIGTERM) + should_be_42 *= 10 - signum = self.signals[0] - signal_handler(signum, None) + # check exectuion stoped when the signal was sent + assert 42 == should_be_42 + # assert signals were caught + assert [signal.SIGTERM] == signals_received + # assert the error handling function was just called once self.init_func.assert_called_once_with(*self.init_args, **self.init_kwargs) - mock_os.kill.assert_called_once_with(mock_os.getpid(), signum) - - self.handler.reset_signal_handlers() - for signum in self.signals: - mock_signal.signal.assert_any_call(signum, signal.SIG_DFL) def test_bad_recovery(self): bad_func = mock.MagicMock(side_effect=[ValueError]) @@ -58,6 +75,18 @@ class ErrorHandlerTest(unittest.TestCase): **self.init_kwargs) bad_func.assert_called_once_with() + def test_bad_recovery_with_signal(self): + bad_func = mock.MagicMock( + side_effect=lambda: send_signal(signal.SIGHUP)) + self.handler.register(bad_func) + with signal_receiver(self.signals) as signals_received: + with self.handler: + send_signal(signal.SIGTERM) + assert [signal.SIGTERM, signal.SIGHUP] == signals_received + self.init_func.assert_called_once_with(*self.init_args, + **self.init_kwargs) + bad_func.assert_called_once_with() + def test_sysexit_ignored(self): try: with self.handler: