diff --git a/certbot-nginx/certbot_nginx/tests/challenges_test.py b/certbot-nginx/certbot_nginx/tests/challenges_test.py index 5d49b6999..cf7c5df08 100644 --- a/certbot-nginx/certbot_nginx/tests/challenges_test.py +++ b/certbot-nginx/certbot_nginx/tests/challenges_test.py @@ -17,9 +17,127 @@ from certbot_nginx import obj from certbot_nginx.tests import util -class TlsSniPerformTest(util.NginxTest): - """Test the NginxTlsSni01 challenge.""" +class ChallengePerformTest(object): + """Abstract base class.""" + def tearDown(self): + shutil.rmtree(self.temp_dir) + shutil.rmtree(self.config_dir) + shutil.rmtree(self.work_dir) + + @mock.patch("certbot_nginx.configurator" + ".NginxConfigurator.choose_vhost") + def test_perform(self, mock_choose): + self.chall_doer.add_chall(self.achalls[1]) + mock_choose.return_value = None + result = self.chall_doer.perform() + self.assertFalse(result is None) + + def test_perform0(self): + responses = self.chall_doer.perform() + self.assertEqual([], responses) + + @mock.patch("certbot_nginx.configurator.NginxConfigurator.save") + def test_perform1(self, mock_save): + self.chall_doer.add_chall(self.achalls[0]) + response = self.achalls[0].response(self.account_key) + mock_setup_cert = mock.MagicMock(return_value=response) + + # pylint: disable=protected-access + if hasattr(self.chall_doer, '_setup_challenge_cert'): + self.chall_doer._setup_challenge_cert = mock_setup_cert + + responses = self.chall_doer.perform() + + if hasattr(self.chall_doer, '_setup_challenge_cert'): + mock_setup_cert.assert_called_once_with(self.achalls[0]) + self.assertEqual([response], responses) + self.assertEqual(mock_save.call_count, 1) + + # Make sure challenge config is included in main config + http = self.chall_doer.configurator.parser.parsed[ + self.chall_doer.configurator.parser.config_root][-1] + self.assertTrue( + util.contains_at_depth(http, ['include', self.chall_doer.challenge_conf], 1)) + + def test_perform2(self): + acme_responses = [] + for achall in self.achalls: + self.chall_doer.add_chall(achall) + acme_responses.append(achall.response(self.account_key)) + + mock_setup_cert = mock.MagicMock(side_effect=acme_responses) + # pylint: disable=protected-access + if hasattr(self.chall_doer, '_setup_challenge_cert'): + self.chall_doer._setup_challenge_cert = mock_setup_cert + + sni_responses = self.chall_doer.perform() + + if hasattr(self.chall_doer, '_setup_challenge_cert'): + self.assertEqual(mock_setup_cert.call_count, 4) + + for index, achall in enumerate(self.achalls): + self.assertEqual( + mock_setup_cert.call_args_list[index], mock.call(achall)) + + http = self.chall_doer.configurator.parser.parsed[ + self.chall_doer.configurator.parser.config_root][-1] + self.assertTrue(['include', self.chall_doer.challenge_conf] in http[1]) + self.assertFalse( + util.contains_at_depth(http, ['server_name', 'another.alias'], 3)) + + self.assertEqual(len(sni_responses), 4) + for i in six.moves.range(4): + self.assertEqual(sni_responses[i], acme_responses[i]) + + def test_mod_config(self): + self.chall_doer.add_chall(self.achalls[0]) + self.chall_doer.add_chall(self.achalls[2]) + + v_addr1 = [obj.Addr("69.50.225.155", "9000", True, False, False, False), + obj.Addr("127.0.0.1", "", False, False, False, False)] + v_addr2 = [obj.Addr("myhost", "", False, True, False, False)] + v_addr2_print = [obj.Addr("myhost", "", False, False, False, False)] + ll_addr = [v_addr1, v_addr2] + self.chall_doer._mod_config(ll_addr) # pylint: disable=protected-access + + self.chall_doer.configurator.save() + + self.chall_doer.configurator.parser.load() + + http = self.chall_doer.configurator.parser.parsed[ + self.chall_doer.configurator.parser.config_root][-1] + self.assertTrue(['include', self.chall_doer.challenge_conf] in http[1]) + + vhosts = self.chall_doer.configurator.parser.get_vhosts() + vhs = [vh for vh in vhosts if vh.filep == self.chall_doer.challenge_conf] + + for vhost in vhs: + if vhost.addrs == set(v_addr1): + achall = self.achalls[0] + response = achall.response(self.account_key) + else: + achall = self.achalls[2] + response = achall.response(self.account_key) + self.assertEqual(vhost.addrs, set(v_addr2_print)) + if hasattr(response, 'z_domain'): + domain = response.z_domain.decode('ascii') + else: + domain = achall.domain + self.assertEqual(vhost.names, set([domain])) + + self.assertEqual(len(vhs), 2) + + def test_mod_config_fail(self): + root = self.chall_doer.configurator.parser.config_root + self.chall_doer.configurator.parser.parsed[root] = [['include', 'foo.conf']] + # pylint: disable=protected-access + self.assertRaises( + errors.MisconfigurationError, self.chall_doer._mod_config, []) + + +class TlsSniPerformTest(util.NginxTest, ChallengePerformTest): + """Test the NginxTlsSni01 challenge.""" account_key = common_test.AUTH_KEY achalls = [ achallenges.KeyAuthorizationAnnotatedChallenge( @@ -47,118 +165,52 @@ class TlsSniPerformTest(util.NginxTest): ] def setUp(self): - super(TlsSniPerformTest, self).setUp() + util.NginxTest.setUp(self) config = util.get_nginx_configurator( self.config_path, self.config_dir, self.work_dir, self.logs_dir) from certbot_nginx import challenges as nginx_challenges - self.sni = nginx_challenges.NginxTlsSni01(config) + self.chall_doer = nginx_challenges.NginxTlsSni01(config) - def tearDown(self): - shutil.rmtree(self.temp_dir) - shutil.rmtree(self.config_dir) - shutil.rmtree(self.work_dir) - @mock.patch("certbot_nginx.configurator" - ".NginxConfigurator.choose_vhost") - def test_perform(self, mock_choose): - self.sni.add_chall(self.achalls[1]) - mock_choose.return_value = None - result = self.sni.perform() - self.assertFalse(result is None) +class HttpPerformTest(util.NginxTest, ChallengePerformTest): + """Test the NginxHttp01 challenge.""" + account_key = common_test.AUTH_KEY + achalls = [ + achallenges.KeyAuthorizationAnnotatedChallenge( + challb=acme_util.chall_to_challb( + challenges.HTTP01(token=b"kNdwjwOeX0I_A8DXt9Msmg"), "pending"), + domain="www.example.com", account_key=account_key), + achallenges.KeyAuthorizationAnnotatedChallenge( + challb=acme_util.chall_to_challb( + challenges.HTTP01( + token=b"\xba\xa9\xda?