From d737546dd709fe5a1f3b8b99ca973b4d3e08a2dc Mon Sep 17 00:00:00 2001 From: Liam Marshall Date: Fri, 20 Nov 2015 16:31:01 -0600 Subject: [PATCH] Split off cleaning into a method (fixes a subtle bug) --- .../letsencrypt_apache/configurator.py | 18 +++--- .../tests/configurator_test.py | 55 +++++++++---------- 2 files changed, 36 insertions(+), 37 deletions(-) diff --git a/letsencrypt-apache/letsencrypt_apache/configurator.py b/letsencrypt-apache/letsencrypt_apache/configurator.py index d80d27d1c..ff95eef95 100644 --- a/letsencrypt-apache/letsencrypt_apache/configurator.py +++ b/letsencrypt-apache/letsencrypt_apache/configurator.py @@ -183,6 +183,7 @@ class ApacheConfigurator(augeas_configurator.AugeasConfigurator): """ vhost = self.choose_vhost(domain) + self._clean_vhost(vhost) # This is done first so that ssl module is enabled and cert_path, # cert_key... can all be parsed appropriately @@ -276,15 +277,7 @@ class ApacheConfigurator(augeas_configurator.AugeasConfigurator): self.assoc[target_name] = vhost return vhost - vhost = self._choose_vhost_from_list(target_name) - if vhost.ssl: - # remove duplicated or conflicting ssl directives - self._deduplicate_directives(vhost.path, - ["SSLCertificateFile", "SSLCertificateKeyFile"]) - # remove all problematic directives - self._remove_directives(vhost.path, ["SSLCertificateChainFile"]) - - return vhost + return self._choose_vhost_from_list(target_name) def _choose_vhost_from_list(self, target_name): # Select a vhost from a list @@ -665,6 +658,13 @@ class ApacheConfigurator(augeas_configurator.AugeasConfigurator): return ssl_addrs + def _clean_vhost(self, vhost): + # remove duplicated or conflicting ssl directives + self._deduplicate_directives(vhost.path, + ["SSLCertificateFile", "SSLCertificateKeyFile"]) + # remove all problematic directives + self._remove_directives(vhost.path, ["SSLCertificateChainFile"]) + def _deduplicate_directives(self, vh_path, directives): for directive in directives: while len(self.parser.find_dir(directive, None, vh_path, False)) > 1: diff --git a/letsencrypt-apache/letsencrypt_apache/tests/configurator_test.py b/letsencrypt-apache/letsencrypt_apache/tests/configurator_test.py index 58aac1216..d5ea540c5 100644 --- a/letsencrypt-apache/letsencrypt_apache/tests/configurator_test.py +++ b/letsencrypt-apache/letsencrypt_apache/tests/configurator_test.py @@ -122,34 +122,6 @@ class TwoVhost80Test(util.ApacheTest): self.assertEqual( self.vh_truth[1], self.config.choose_vhost("none.com")) - @mock.patch("letsencrypt_apache.display_ops.select_vhost") - def test_choose_vhost_cleans_vhost_ssl(self, mock_select): - for directive in ["SSLCertificateFile", "SSLCertificateKeyFile", - "SSLCertificateChainFile", "SSLCACertificatePath"]: - for _ in range(10): - self.config.parser.add_dir(self.vh_truth[1].path, directive, ["bogus"]) - self.config.save() - mock_select.return_value = self.vh_truth[1] - - self.config.choose_vhost("none.com") - self.config.save() - - loc_cert = self.config.parser.find_dir( - 'SSLCertificateFile', None, self.vh_truth[1].path, False) - loc_key = self.config.parser.find_dir( - 'SSLCertificateKeyFile', None, self.vh_truth[1].path, False) - loc_chain = self.config.parser.find_dir( - 'SSLCertificateChainFile', None, self.vh_truth[1].path, False) - loc_cacert = self.config.parser.find_dir( - 'SSLCACertificatePath', None, self.vh_truth[1].path, False) - - self.assertEqual(len(loc_cert), 1) - self.assertEqual(len(loc_key), 1) - - self.assertEqual(len(loc_chain), 0) - - self.assertEqual(len(loc_cacert), 10) - @mock.patch("letsencrypt_apache.display_ops.select_vhost") def test_choose_vhost_select_vhost_non_ssl(self, mock_select): mock_select.return_value = self.vh_truth[0] @@ -433,6 +405,33 @@ class TwoVhost80Test(util.ApacheTest): self.assertEqual(len(self.config.vhosts), 5) + def test_clean_vhost_ssl(self): + # pylint: disable=protected-access + for directive in ["SSLCertificateFile", "SSLCertificateKeyFile", + "SSLCertificateChainFile", "SSLCACertificatePath"]: + for _ in range(10): + self.config.parser.add_dir(self.vh_truth[1].path, directive, ["bogus"]) + self.config.save() + + self.config._clean_vhost(self.vh_truth[1]) + self.config.save() + + loc_cert = self.config.parser.find_dir( + 'SSLCertificateFile', None, self.vh_truth[1].path, False) + loc_key = self.config.parser.find_dir( + 'SSLCertificateKeyFile', None, self.vh_truth[1].path, False) + loc_chain = self.config.parser.find_dir( + 'SSLCertificateChainFile', None, self.vh_truth[1].path, False) + loc_cacert = self.config.parser.find_dir( + 'SSLCACertificatePath', None, self.vh_truth[1].path, False) + + self.assertEqual(len(loc_cert), 1) + self.assertEqual(len(loc_key), 1) + + self.assertEqual(len(loc_chain), 0) + + self.assertEqual(len(loc_cacert), 10) + def test_deduplicate_directives(self): # pylint: disable=protected-access DIRECTIVE = "Foo"