diff --git a/certbot-nginx/certbot_nginx/configurator.py b/certbot-nginx/certbot_nginx/configurator.py index 40e53a0e6..0a23f1f07 100644 --- a/certbot-nginx/certbot_nginx/configurator.py +++ b/certbot-nginx/certbot_nginx/configurator.py @@ -24,7 +24,6 @@ from certbot.plugins import common from certbot_nginx import constants from certbot_nginx import tls_sni_01 -from certbot_nginx import obj from certbot_nginx import parser @@ -154,7 +153,7 @@ class NginxConfigurator(common.Plugin): ['\n', 'ssl_certificate_key', ' ', key_path]] try: - self.parser.add_server_directives(vhost.filep, vhost.names, + self.parser.add_server_directives(vhost, cert_directives, replace=True) logger.info("Deployed Certificate to VirtualHost %s for %s", vhost.filep, vhost.names) @@ -198,12 +197,9 @@ class NginxConfigurator(common.Plugin): matches = self._get_ranked_matches(target_name) if not matches: - # No matches. Create a new vhost with this name in nginx.conf. - filep = self.parser.loc["root"] - new_block = [['server'], [['\n', 'server_name', ' ', target_name]]] - self.parser.add_http_directives(filep, new_block) - vhost = obj.VirtualHost(filep, set([]), False, True, - set([target_name]), list(new_block[1])) + # No matches. Raise a misconfiguration error. + raise errors.MisconfigurationError( + "Cannot find a VirtualHost matching domain %s." % (target_name)) elif matches[0]['rank'] in xrange(2, 6): # Wildcard match - need to find the longest one rank = matches[0]['rank'] @@ -341,11 +337,7 @@ class NginxConfigurator(common.Plugin): self.parser.loc["ssl_options"]) self.parser.add_server_directives( - vhost.filep, vhost.names, ssl_block, replace=False) - vhost.ssl = True - vhost.raw.extend(ssl_block) - vhost.addrs.add(obj.Addr( - '', str(self.config.tls_sni_01_port), True, False)) + vhost, ssl_block, replace=False) def get_all_certs_keys(self): """Find all existing keys, certs from configuration. @@ -406,7 +398,7 @@ class NginxConfigurator(common.Plugin): '\n '] ], ['\n']] self.parser.add_server_directives( - vhost.filep, vhost.names, redirect_block, replace=False) + vhost, redirect_block, replace=False) logger.info("Redirecting all traffic to ssl in %s", vhost.filep) def _enable_ocsp_stapling(self, vhost, chain_path): @@ -435,7 +427,7 @@ class NginxConfigurator(common.Plugin): ['\n ', 'ssl_stapling_verify', ' ', 'on'], ['\n']] try: - self.parser.add_server_directives(vhost.filep, vhost.names, + self.parser.add_server_directives(vhost, stapling_directives, replace=False) except errors.MisconfigurationError as error: logger.debug(error) diff --git a/certbot-nginx/certbot_nginx/obj.py b/certbot-nginx/certbot_nginx/obj.py index f5ac88f6c..8c93d0a8b 100644 --- a/certbot-nginx/certbot_nginx/obj.py +++ b/certbot-nginx/certbot_nginx/obj.py @@ -107,10 +107,12 @@ class VirtualHost(object): # pylint: disable=too-few-public-methods :ivar bool ssl: SSLEngine on in vhost :ivar bool enabled: Virtual host is enabled + :ivar list path: The indices into the parsed file used to access + the server block defining the vhost """ - def __init__(self, filep, addrs, ssl, enabled, names, raw): + def __init__(self, filep, addrs, ssl, enabled, names, raw, path): # pylint: disable=too-many-arguments """Initialize a VH.""" self.filep = filep @@ -119,6 +121,7 @@ class VirtualHost(object): # pylint: disable=too-few-public-methods self.ssl = ssl self.enabled = enabled self.raw = raw + self.path = path def __str__(self): addr_str = ", ".join(str(addr) for addr in self.addrs) @@ -137,6 +140,8 @@ class VirtualHost(object): # pylint: disable=too-few-public-methods return (self.filep == other.filep and list(self.addrs) == list(other.addrs) and self.names == other.names and - self.ssl == other.ssl and self.enabled == other.enabled) + self.ssl == other.ssl and + self.enabled == other.enabled and + self.path == other.path) return False diff --git a/certbot-nginx/certbot_nginx/parser.py b/certbot-nginx/certbot_nginx/parser.py index 3919858d9..13bb38359 100644 --- a/certbot-nginx/certbot_nginx/parser.py +++ b/certbot-nginx/certbot_nginx/parser.py @@ -104,15 +104,15 @@ class NginxParser(object): # Find all the server blocks _do_for_subarray(tree, lambda x: x[0] == ['server'], - lambda x: srv.append(x[1])) + lambda x, y: srv.append((x[1], y))) # Find 'include' statements in server blocks and append their trees - for i, server in enumerate(servers[filename]): + for i, (server, path) in enumerate(servers[filename]): new_server = self._get_included_directives(server) - servers[filename][i] = new_server + servers[filename][i] = (new_server, path) for filename in servers: - for server in servers[filename]: + for server, path in servers[filename]: # Parse the server block into a VirtualHost object parsed_server = parse_server(server) @@ -121,7 +121,8 @@ class NginxParser(object): parsed_server['ssl'], enabled, parsed_server['names'], - server) + server, + path) vhosts.append(vhost) return vhosts @@ -240,42 +241,10 @@ class NginxParser(object): except IOError: logger.error("Could not open file for writing: %s", filename) - def _has_server_names(self, entry, names): - """Checks if a server block has the given set of server_names. This - is the primary way of identifying server blocks in the configurator. - Returns false if 'entry' doesn't look like a server block at all. + def add_server_directives(self, vhost, directives, replace): + """Add or replace directives in the server block identified by vhost. - ..todo :: Doesn't match server blocks whose server_name directives are - split across multiple conf files. - - :param list entry: The block to search - :param set names: The names to match - :rtype: bool - - """ - if len(names) == 0: - # Nothing to identify blocks with - return False - - if not isinstance(entry, list): - # Can't be a server block - return False - - new_entry = self._get_included_directives(entry) - server_names = set() - for item in new_entry: - if not isinstance(item, list): - # Can't be a server block - return False - - if len(item) > 0 and item[0] == 'server_name': - server_names.update(_get_servernames(item[1])) - - return server_names == names - - def add_server_directives(self, filename, names, directives, - replace): - """Add or replace directives in the first server block with names. + This method modifies vhost to be fully consistent with the new directives. ..note :: If replace is True, this raises a misconfiguration error if the directive does not already exist. @@ -285,34 +254,32 @@ class NginxParser(object): ..todo :: Doesn't match server blocks whose server_name directives are split across multiple conf files. - :param str filename: The absolute filename of the config file - :param set names: The server_name to match + :param :class:`~certbot_nginx.obj.VirtualHost` vhost: The vhost + whose information we use to match on :param list directives: The directives to add :param bool replace: Whether to only replace existing directives """ + filename = vhost.filep try: - _do_for_subarray(self.parsed[filename], - lambda x: self._has_server_names(x, names), - lambda x: _add_directives(x, directives, replace)) + result = self.parsed[filename] + for index in vhost.path: + result = result[index] + if not isinstance(result, list) or len(result) != 2: + raise errors.MisconfigurationError("Not a server block.") + result = result[1] + _add_directives(result, directives, replace) + + # update vhost based on new directives + new_server = self._get_included_directives(result) + parsed_server = parse_server(new_server) + vhost.addrs = parsed_server['addrs'] + vhost.ssl = parsed_server['ssl'] + vhost.names = parsed_server['names'] + vhost.raw = new_server except errors.MisconfigurationError as err: raise errors.MisconfigurationError("Problem in %s: %s" % (filename, err.message)) - def add_http_directives(self, filename, directives): - """Adds directives to the first encountered HTTP block in filename. - - We insert new directives at the top of the block to work around - https://trac.nginx.org/nginx/ticket/810: If the first server block - doesn't enable OCSP stapling, stapling is broken for all blocks. - - :param str filename: The absolute filename of the config file - :param list directives: The directives to add - - """ - _do_for_subarray(self.parsed[filename], - lambda x: x[0] == ['http'], - lambda x: x[1].insert(0, directives)) - def get_all_certs_keys(self): """Gets all certs and keys in the nginx config. @@ -341,7 +308,7 @@ class NginxParser(object): return c_k -def _do_for_subarray(entry, condition, func): +def _do_for_subarray(entry, condition, func, path=None): """Executes a function for a subarray of a nested array if it matches the given condition. @@ -350,12 +317,14 @@ def _do_for_subarray(entry, condition, func): :param function func: The function to call for each matching item """ + if path is None: + path = [] if isinstance(entry, list): if condition(entry): - func(entry) + func(entry, path) else: - for item in entry: - _do_for_subarray(item, condition, func) + for index, item in enumerate(entry): + _do_for_subarray(item, condition, func, path + [index]) def get_best_match(target_name, names): diff --git a/certbot-nginx/certbot_nginx/tests/configurator_test.py b/certbot-nginx/certbot_nginx/tests/configurator_test.py index 4b0117806..9bb8a46d8 100644 --- a/certbot-nginx/certbot_nginx/tests/configurator_test.py +++ b/certbot-nginx/certbot_nginx/tests/configurator_test.py @@ -13,6 +13,7 @@ from acme import messages from certbot import achallenges from certbot import errors +from certbot_nginx import obj from certbot_nginx import parser from certbot_nginx.tests import util @@ -83,8 +84,12 @@ class NginxConfiguratorTest(util.NginxTest): def test_save(self): filep = self.config.parser.abs_path('sites-enabled/example.com') + mock_vhost = obj.VirtualHost(filep, + None, None, None, + set(['.example.com', 'example.*']), + None, [0]) self.config.parser.add_server_directives( - filep, set(['.example.com', 'example.*']), + mock_vhost, [['listen', ' ', '5001 ssl']], replace=False) self.config.save() @@ -135,7 +140,8 @@ class NginxConfiguratorTest(util.NginxTest): self.assertEqual(conf_path[name], path) for name in bad_results: - self.assertEqual(set([name]), self.config.choose_vhost(name).names) + self.assertRaises(errors.MisconfigurationError, + self.config.choose_vhost, name) def test_more_info(self): self.assertTrue('nginx.conf' in self.config.more_info()) diff --git a/certbot-nginx/certbot_nginx/tests/obj_test.py b/certbot-nginx/certbot_nginx/tests/obj_test.py index e7a993d1b..200f2acb9 100644 --- a/certbot-nginx/certbot_nginx/tests/obj_test.py +++ b/certbot-nginx/certbot_nginx/tests/obj_test.py @@ -80,7 +80,7 @@ class VirtualHostTest(unittest.TestCase): self.vhost1 = VirtualHost( "filep", set([Addr.fromstring("localhost")]), False, False, - set(['localhost']), []) + set(['localhost']), [], []) def test_eq(self): from certbot_nginx.obj import Addr @@ -88,7 +88,7 @@ class VirtualHostTest(unittest.TestCase): vhost1b = VirtualHost( "filep", set([Addr.fromstring("localhost blah")]), False, False, - set(['localhost']), []) + set(['localhost']), [], []) self.assertEqual(vhost1b, self.vhost1) self.assertEqual(str(vhost1b), str(self.vhost1)) diff --git a/certbot-nginx/certbot_nginx/tests/parser_test.py b/certbot-nginx/certbot_nginx/tests/parser_test.py index 71807d4f4..18de59daf 100644 --- a/certbot-nginx/certbot_nginx/tests/parser_test.py +++ b/certbot-nginx/certbot_nginx/tests/parser_test.py @@ -79,6 +79,30 @@ class NginxParserTest(util.NginxTest): ['server_name', 'example.*']]]], parsed[0]) + def test__do_for_subarray(self): + # pylint: disable=protected-access + mylists = [([[2], [3], [2]], [[0], [2]]), + ([[2], [3], [4]], [[0]]), + ([[4], [3], [2]], [[2]]), + ([], []), + (2, []), + ([[[2], [3], [2]], [[2], [3], [2]]], + [[0, 0], [0, 2], [1, 0], [1, 2]]), + ([[[0], [3], [2]], [[2], [3], [2]]], [[0, 2], [1, 0], [1, 2]]), + ([[[0], [3], [4]], [[2], [3], [2]]], [[1, 0], [1, 2]]), + ([[[0], [3], [4]], [[5], [3], [2]]], [[1, 2]]), + ([[[0], [3], [4]], [[5], [3], [0]]], [])] + + for mylist, result in mylists: + paths = [] + parser._do_for_subarray(mylist, + lambda x: isinstance(x, list) and + len(x) >= 1 and + x[0] == 2, + lambda x, y, pts=paths: pts.append(y)) + self.assertEqual(paths, result) + + def test_get_vhosts(self): nparser = parser.NginxParser(self.config_path, self.ssl_options) vhosts = nparser.get_vhosts() @@ -88,26 +112,28 @@ class NginxParserTest(util.NginxTest): False, True, set(['localhost', r'~^(www\.)?(example|bar)\.']), - []) + [], [9, 1, 9]) vhost2 = obj.VirtualHost(nparser.abs_path('nginx.conf'), [obj.Addr('somename', '8080', False, False), obj.Addr('', '8000', False, False)], False, True, set(['somename', 'another.alias', 'alias']), - []) + [], [9, 1, 12]) vhost3 = obj.VirtualHost(nparser.abs_path('sites-enabled/example.com'), [obj.Addr('69.50.225.155', '9000', False, False), obj.Addr('127.0.0.1', '', False, False)], False, True, - set(['.example.com', 'example.*']), []) + set(['.example.com', 'example.*']), [], [0]) vhost4 = obj.VirtualHost(nparser.abs_path('sites-enabled/default'), [obj.Addr('myhost', '', False, True)], - False, True, set(['www.example.org']), []) + False, True, set(['www.example.org']), + [], [0]) vhost5 = obj.VirtualHost(nparser.abs_path('foo.conf'), [obj.Addr('*', '80', True, True)], True, True, set(['*.www.foo.com', - '*.www.example.com']), []) + '*.www.example.com']), + [], [2, 1, 0]) self.assertEqual(5, len(vhosts)) example_com = [x for x in vhosts if 'example.com' in x.filep][0] @@ -123,9 +149,12 @@ class NginxParserTest(util.NginxTest): def test_add_server_directives(self): nparser = parser.NginxParser(self.config_path, self.ssl_options) - nparser.add_server_directives(nparser.abs_path('nginx.conf'), - set(['localhost', + mock_vhost = obj.VirtualHost(nparser.abs_path('nginx.conf'), + None, None, None, + set(['localhost', r'~^(www\.)?(example|bar)\.']), + None, [9, 1, 9]) + nparser.add_server_directives(mock_vhost, [['foo', 'bar'], ['\n ', 'ssl_certificate', ' ', '/etc/ssl/cert.pem']], replace=False) @@ -133,47 +162,48 @@ class NginxParserTest(util.NginxTest): dump = nginxparser.dumps(nparser.parsed[nparser.abs_path('nginx.conf')]) self.assertEqual(1, len(re.findall(ssl_re, dump))) - server_conf = nparser.abs_path('server.conf') - names = set(['alias', 'another.alias', 'somename']) - nparser.add_server_directives(server_conf, names, + example_com = nparser.abs_path('sites-enabled/example.com') + names = set(['.example.com', 'example.*']) + mock_vhost.filep = example_com + mock_vhost.names = names + mock_vhost.path = [0] + nparser.add_server_directives(mock_vhost, [['foo', 'bar'], ['ssl_certificate', '/etc/ssl/cert2.pem']], replace=False) - nparser.add_server_directives(server_conf, names, [['foo', 'bar']], + nparser.add_server_directives(mock_vhost, [['foo', 'bar']], replace=False) from certbot_nginx.parser import COMMENT - self.assertEqual(nparser.parsed[server_conf], - [['server_name', 'somename alias another.alias'], - ['foo', 'bar'], - ['#', COMMENT], - ['ssl_certificate', '/etc/ssl/cert2.pem'], - ['#', COMMENT], - [], [] - ]) + self.assertEqual(nparser.parsed[example_com], + [[['server'], [['listen', '69.50.225.155:9000'], + ['listen', '127.0.0.1'], + ['server_name', '.example.com'], + ['server_name', 'example.*'], + ['foo', 'bar'], + ['#', COMMENT], + ['ssl_certificate', '/etc/ssl/cert2.pem'], + ['#', COMMENT], [], [] + ]]]) - def test_add_http_directives(self): - nparser = parser.NginxParser(self.config_path, self.ssl_options) - filep = nparser.abs_path('nginx.conf') - block = [['server'], - [['listen', '80'], - ['server_name', 'localhost']]] - nparser.add_http_directives(filep, block) - root = nparser.parsed[filep] - self.assertTrue(util.contains_at_depth(root, ['http'], 1)) - self.assertTrue(util.contains_at_depth(root, block, 2)) - - # Check that our server block got inserted first among all server - # blocks. - http_block = [x for x in root if x[0] == ['http']][0][1] - server_blocks = [x for x in http_block if x[0] == ['server']] - self.assertEqual(server_blocks[0], block) + server_conf = nparser.abs_path('server.conf') + names = set(['alias', 'another.alias', 'somename']) + mock_vhost.filep = server_conf + mock_vhost.names = names + mock_vhost.path = [] + self.assertRaises(errors.MisconfigurationError, + nparser.add_server_directives, + mock_vhost, + [['foo', 'bar'], + ['ssl_certificate', '/etc/ssl/cert2.pem']], + replace=False) def test_replace_server_directives(self): nparser = parser.NginxParser(self.config_path, self.ssl_options) target = set(['.example.com', 'example.*']) filep = nparser.abs_path('sites-enabled/example.com') + mock_vhost = obj.VirtualHost(filep, None, None, None, target, None, [0]) nparser.add_server_directives( - filep, target, [['server_name', 'foobar.com']], replace=True) + mock_vhost, [['server_name', 'foobar.com']], replace=True) from certbot_nginx.parser import COMMENT self.assertEqual( nparser.parsed[filep], @@ -182,9 +212,10 @@ class NginxParserTest(util.NginxTest): ['server_name', 'foobar.com'], ['#', COMMENT], ['server_name', 'example.*'], [] ]]]) + mock_vhost.names = set(['foobar.com', 'example.*']) self.assertRaises(errors.MisconfigurationError, nparser.add_server_directives, - filep, set(['foobar.com', 'example.*']), + mock_vhost, [['ssl_certificate', 'cert.pem']], replace=True) @@ -241,8 +272,11 @@ class NginxParserTest(util.NginxTest): def test_get_all_certs_keys(self): nparser = parser.NginxParser(self.config_path, self.ssl_options) filep = nparser.abs_path('sites-enabled/example.com') - nparser.add_server_directives(filep, - set(['.example.com', 'example.*']), + mock_vhost = obj.VirtualHost(filep, + None, None, None, + set(['.example.com', 'example.*']), + None, [0]) + nparser.add_server_directives(mock_vhost, [['ssl_certificate', 'foo.pem'], ['ssl_certificate_key', 'bar.key'], ['listen', '443 ssl']], diff --git a/certbot-nginx/certbot_nginx/tests/tls_sni_01_test.py b/certbot-nginx/certbot_nginx/tests/tls_sni_01_test.py index a92caf788..283e326e9 100644 --- a/certbot-nginx/certbot_nginx/tests/tls_sni_01_test.py +++ b/certbot-nginx/certbot_nginx/tests/tls_sni_01_test.py @@ -31,7 +31,7 @@ class TlsSniPerformTest(util.NginxTest): token="\xba\xa9\xda?