From fcc76618faa84ccf7341491fb8dc3134462ad65e Mon Sep 17 00:00:00 2001 From: sydneyli Date: Wed, 1 May 2019 11:08:06 -0700 Subject: [PATCH] Everything working with add_directive except save --- .../certbot_nginx/nginx_parser_obj.py | 41 +++++++++++-- certbot-nginx/certbot_nginx/parser.py | 60 +++++++++++-------- certbot-nginx/certbot_nginx/parser_obj.py | 25 ++++++++ .../certbot_nginx/tests/configurator_test.py | 18 +++--- 4 files changed, 103 insertions(+), 41 deletions(-) diff --git a/certbot-nginx/certbot_nginx/nginx_parser_obj.py b/certbot-nginx/certbot_nginx/nginx_parser_obj.py index 452831312..5e53b71ca 100644 --- a/certbot-nginx/certbot_nginx/nginx_parser_obj.py +++ b/certbot-nginx/certbot_nginx/nginx_parser_obj.py @@ -3,6 +3,7 @@ import glob import logging import pyparsing +from certbot import errors from certbot.compat import os from certbot_nginx import nginxparser @@ -91,6 +92,9 @@ class ServerBlock(obj.Block): def __init__(self, context=None): super(ServerBlock, self).__init__(context) self.vhost = None + self.addrs = set() + self.ssl = False + self.server_names = set() @staticmethod def should_parse(lists): @@ -114,16 +118,37 @@ class ServerBlock(obj.Block): if directive[0] == 'server_name': self.server_names.update(x.strip('"\'') for x in directive[1:]) for ssl in self.get_directives('ssl'): - if ssl.words[1] == "on": + if ssl[1] == "on": self.ssl = True apply_ssl_to_all_addrs = True if apply_ssl_to_all_addrs: for addr in self.addrs: addr.ssl = True - return nginx_obj.VirtualHost( - self.context.filename if self.context is not None else "", - self.addrs, self.ssl, True, self.server_names, self.dump_unspaced_list()[1], - self.get_path(), self) + self.vhost.addrs = self.addrs + self.vhost.names = self.server_names + self.vhost.ssl = self.ssl + self.vhost.raw = self.dump_unspaced_list()[1] + self.vhost.raw_obj = self + + def add_directive(self, raw_list, insert_at_top=False): + """ Adds a single directive to this Server Block's contents, while enforcing + repeatability rules.""" + statement = obj.parse_raw(raw_list, self.contents.child_context(), add_spaces=False) + if isinstance(statement, obj.Sentence) and statement[0] not in self.REPEATABLE_DIRECTIVES \ + and len(list(self.get_directives(statement[0]))) > 0: + raise errors.MisconfigurationError( + "Existing %s directive conflicts with %s" % (statement[0], statement)) + self.contents.add_directive(statement, insert_at_top) + + def update_or_add_directive(self, raw_list, insert_at_top=False): + """ Adds a single directive to this Server Block's contents, while enforcing + repeatability rules.""" + statement = obj.parse_raw(raw_list, self.contents.child_context(), add_spaces=False) + index = self.contents.find_directive(lambda elem: elem[0] == statement[0]) + if index < 0: + self.contents.add_directive(statement, insert_at_top) + return + self.contents.update_directive(statement, index) def get_directives(self, name, match=None): """ Retrieves any child directive starting with `name`. @@ -138,6 +163,10 @@ class ServerBlock(obj.Block): """ Parses lists into a ServerBlock object, and creates a corresponding VirtualHost metadata object. """ super(ServerBlock, self).parse(raw_list, add_spaces) - self.vhost = self._update_vhost() + self.vhost = nginx_obj.VirtualHost( + self.context.filename if self.context is not None else "", + self.addrs, self.ssl, True, self.server_names, self.dump_unspaced_list()[1], + self.get_path(), self) + self._update_vhost() NGINX_PARSING_HOOKS = (ServerBlock, obj.Block, Include, obj.Sentence, obj.Directives) diff --git a/certbot-nginx/certbot_nginx/parser.py b/certbot-nginx/certbot_nginx/parser.py index 0f5881b9c..b49166432 100644 --- a/certbot-nginx/certbot_nginx/parser.py +++ b/certbot-nginx/certbot_nginx/parser.py @@ -95,18 +95,14 @@ class NginxParser(object): def _build_addr_to_ssl(self): """Builds a map from address to whether it listens on ssl in any server block """ - servers = self._get_raw_servers() - - addr_to_ssl = {} # type: Dict[Tuple[str, str], bool] - for filename in servers: - for server, _ in servers[filename]: - # Parse the server block to save addr info - parsed_server = _parse_server_raw(server) - for addr in parsed_server['addrs']: - addr_tuple = addr.normalized_tuple() - if addr_tuple not in addr_to_ssl: - addr_to_ssl[addr_tuple] = addr.ssl - addr_to_ssl[addr_tuple] = addr.ssl or addr_to_ssl[addr_tuple] + servers = self.parsed_root.get_type(nginx_obj.ServerBlock) + addr_to_ssl = {} + for server in servers: + for addr in server.vhost.addrs: + addr_tuple = addr.normalized_tuple() + if addr_tuple not in addr_to_ssl: + addr_to_ssl[addr_tuple] = addr.ssl + addr_to_ssl[addr_tuple] = addr.ssl or addr_to_ssl[addr_tuple] return addr_to_ssl def _get_raw_servers(self): @@ -124,7 +120,7 @@ class NginxParser(object): _do_for_subarray(tree, lambda x: len(x) >= 2 and x[0] == ['server'], lambda x, y: srv.append((x[1], y))) - # Find 'include' statements in server blocks and append their trees + # Find 'include statement in server blocks and append their trees for i, (server, path) in enumerate(servers[filename]): new_server = self._get_included_directives(server) servers[filename][i] = (new_server, path) @@ -295,8 +291,18 @@ class NginxParser(object): of the server block instead of the bottom """ - self._modify_server_directives(vhost, - functools.partial(_add_directives, directives, insert_at_top)) + + for directive in directives: + if not _is_whitespace_or_empty(directive): + vhost.raw_obj.add_directive(directive, insert_at_top) + self._sync_old_structs(vhost) + + def _sync_old_structs(self, vhost): + vhost.raw_obj._update_vhost() + old = self.parsed[vhost.filep] + for index in vhost.path[:-1]: + old = old[index] + old[vhost.path[-1]] = vhost.raw_obj.dump_unspaced_list() def update_or_add_server_directives(self, vhost, directives, insert_at_top=False): """Add or replace directives in the server block identified by vhost. @@ -317,8 +323,10 @@ class NginxParser(object): of the server block instead of the bottom """ - self._modify_server_directives(vhost, - functools.partial(_update_or_add_directives, directives, insert_at_top)) + for directive in directives: + if not _is_whitespace_or_empty(directive): + vhost.raw_obj.update_or_add_directive(directive, insert_at_top) + self._sync_old_structs(vhost) def remove_server_directives(self, vhost, directive_name, match_func=None): """Remove all directives of type directive_name. @@ -389,7 +397,6 @@ class NginxParser(object): new_directives.append(directive) new_directives.append("\n") raw_in_parsed[1] = new_directives - self._update_vhost_based_on_new_directives(new_vhost, new_directives) enclosing_block.append(raw_in_parsed) @@ -564,18 +571,11 @@ def _is_ssl_on_directive(entry): len(entry) == 2 and entry[0] == 'ssl' and entry[1] == 'on') -def _add_directives(directives, insert_at_top, block): - """Adds directives to a config block.""" - for directive in directives: - _add_directive(block, directive, insert_at_top) - if block and '\n' not in block[-1]: # could be " \n " or ["\n"] ! - block.append(nginxparser.UnspacedList('\n')) - def _update_or_add_directives(directives, insert_at_top, block): """Adds or replaces directives in a config block.""" for directive in directives: _update_or_add_directive(block, directive, insert_at_top) - if block and '\n' not in block[-1]: # could be " \n " or ["\n"] ! + if block and '\n' not in block.spaced[-1]: # could be " \n " or ["\n"] ! block.append(nginxparser.UnspacedList('\n')) @@ -634,6 +634,7 @@ def _is_whitespace_or_comment(directive): """Is this directive either a whitespace or comment directive?""" return len(directive) == 0 or directive[0] == '#' +# block = Statements def _add_directive(block, directive, insert_at_top): if not isinstance(directive, nginxparser.UnspacedList): directive = nginxparser.UnspacedList(directive) @@ -733,6 +734,13 @@ def _apply_global_addr_ssl(addr_to_ssl, parsed_server): if addr.ssl: parsed_server['ssl'] = True +def _is_whitespace_or_empty(directive): + if not directive: + return True + if len(directive) != 1: + return False + return len(directive[0]) == 0 or directive[0].isspace() + def _parse_server_raw(server): """Parses a list of server directives. diff --git a/certbot-nginx/certbot_nginx/parser_obj.py b/certbot-nginx/certbot_nginx/parser_obj.py index 266d70c14..38c1f9f17 100644 --- a/certbot-nginx/certbot_nginx/parser_obj.py +++ b/certbot-nginx/certbot_nginx/parser_obj.py @@ -168,6 +168,31 @@ class Directives(Parsable): # ======== End overridden functions + def update_directive(self, statement, index): + """ upd8 + """ + self._data[index] = statement + if index + 1 >= len(self._data) or not _is_certbot_comment(self._data[index+1]): + self._data.insert(index+1, _certbot_comment(self.context)) + + def find_directive(self, match_func): + for i, elem in enumerate(self._data): + if isinstance(elem, Sentence) and match_func(elem): + return i + return -1 + + def add_directive(self, statement, insert_at_top=False): + """ Takes in a parse obj + """ + index = 0 + if insert_at_top: + self._data.insert(0, statement) + else: + index = len(self._data) + self._data.append(statement) + if not _is_comment(statement): + self._data.insert(index+1, _certbot_comment(self.context)) + def get_type(self, match_type): """ TODO """ diff --git a/certbot-nginx/certbot_nginx/tests/configurator_test.py b/certbot-nginx/certbot_nginx/tests/configurator_test.py index 73fb2eba5..90af8085d 100644 --- a/certbot-nginx/certbot_nginx/tests/configurator_test.py +++ b/certbot-nginx/certbot_nginx/tests/configurator_test.py @@ -177,6 +177,7 @@ class NginxConfiguratorTest(util.NginxTest): if name == "ipv6.com": self.assertTrue(vhost.ipv6_enabled()) # Make sure that we have SSL enabled also for IPv6 addr + print vhost self.assertTrue( any([True for x in vhost.addrs if x.ssl and x.ipv6])) @@ -494,23 +495,23 @@ class NginxConfiguratorTest(util.NginxTest): self.assertEqual( [[['server'], [ ['server_name', '.example.com'], - ['server_name', 'example.*'], [], + ['server_name', 'example.*'], ['listen', '5001', 'ssl'], ['#', ' managed by Certbot'], ['ssl_certificate', 'example/fullchain.pem'], ['#', ' managed by Certbot'], ['ssl_certificate_key', 'example/key.pem'], ['#', ' managed by Certbot'], ['include', self.config.mod_ssl_conf], ['#', ' managed by Certbot'], ['ssl_dhparam', self.config.ssl_dhparams], ['#', ' managed by Certbot'], - [], []]], + ]], [['server'], [ [['if', '($host', '=', 'www.example.com)'], [ ['return', '301', 'https://$host$request_uri']]], - ['#', ' managed by Certbot'], [], + ['#', ' managed by Certbot'], ['listen', '69.50.225.155:9000'], ['listen', '127.0.0.1'], ['server_name', '.example.com'], ['server_name', 'example.*'], - ['return', '404'], ['#', ' managed by Certbot'], [], [], []]]], - generated_conf) + ['return', '404'], ['#', ' managed by Certbot'], ]]][0], + generated_conf[0]) def test_split_for_headers(self): example_conf = self.config.parser.abs_path('sites-enabled/example.com') @@ -525,22 +526,21 @@ class NginxConfiguratorTest(util.NginxTest): self.assertEqual( [[['server'], [ ['server_name', '.example.com'], - ['server_name', 'example.*'], [], + ['server_name', 'example.*'], ['listen', '5001', 'ssl'], ['#', ' managed by Certbot'], ['ssl_certificate', 'example/fullchain.pem'], ['#', ' managed by Certbot'], ['ssl_certificate_key', 'example/key.pem'], ['#', ' managed by Certbot'], ['include', self.config.mod_ssl_conf], ['#', ' managed by Certbot'], ['ssl_dhparam', self.config.ssl_dhparams], ['#', ' managed by Certbot'], - [], [], ['add_header', 'Strict-Transport-Security', '"max-age=31536000"', 'always'], ['#', ' managed by Certbot'], - [], []]], + ]], [['server'], [ ['listen', '69.50.225.155:9000'], ['listen', '127.0.0.1'], ['server_name', '.example.com'], ['server_name', 'example.*'], - [], [], []]]], + ]]], generated_conf) def test_http_header_hsts(self):