Everything working with add_directive except save

This commit is contained in:
sydneyli 2019-05-01 11:08:06 -07:00
parent 7a79c55af8
commit fcc76618fa
4 changed files with 103 additions and 41 deletions

View file

@ -3,6 +3,7 @@ import glob
import logging
import pyparsing
from certbot import errors
from certbot.compat import os
from certbot_nginx import nginxparser
@ -91,6 +92,9 @@ class ServerBlock(obj.Block):
def __init__(self, context=None):
super(ServerBlock, self).__init__(context)
self.vhost = None
self.addrs = set()
self.ssl = False
self.server_names = set()
@staticmethod
def should_parse(lists):
@ -114,16 +118,37 @@ class ServerBlock(obj.Block):
if directive[0] == 'server_name':
self.server_names.update(x.strip('"\'') for x in directive[1:])
for ssl in self.get_directives('ssl'):
if ssl.words[1] == "on":
if ssl[1] == "on":
self.ssl = True
apply_ssl_to_all_addrs = True
if apply_ssl_to_all_addrs:
for addr in self.addrs:
addr.ssl = True
return nginx_obj.VirtualHost(
self.context.filename if self.context is not None else "",
self.addrs, self.ssl, True, self.server_names, self.dump_unspaced_list()[1],
self.get_path(), self)
self.vhost.addrs = self.addrs
self.vhost.names = self.server_names
self.vhost.ssl = self.ssl
self.vhost.raw = self.dump_unspaced_list()[1]
self.vhost.raw_obj = self
def add_directive(self, raw_list, insert_at_top=False):
""" Adds a single directive to this Server Block's contents, while enforcing
repeatability rules."""
statement = obj.parse_raw(raw_list, self.contents.child_context(), add_spaces=False)
if isinstance(statement, obj.Sentence) and statement[0] not in self.REPEATABLE_DIRECTIVES \
and len(list(self.get_directives(statement[0]))) > 0:
raise errors.MisconfigurationError(
"Existing %s directive conflicts with %s" % (statement[0], statement))
self.contents.add_directive(statement, insert_at_top)
def update_or_add_directive(self, raw_list, insert_at_top=False):
""" Adds a single directive to this Server Block's contents, while enforcing
repeatability rules."""
statement = obj.parse_raw(raw_list, self.contents.child_context(), add_spaces=False)
index = self.contents.find_directive(lambda elem: elem[0] == statement[0])
if index < 0:
self.contents.add_directive(statement, insert_at_top)
return
self.contents.update_directive(statement, index)
def get_directives(self, name, match=None):
""" Retrieves any child directive starting with `name`.
@ -138,6 +163,10 @@ class ServerBlock(obj.Block):
""" Parses lists into a ServerBlock object, and creates a
corresponding VirtualHost metadata object. """
super(ServerBlock, self).parse(raw_list, add_spaces)
self.vhost = self._update_vhost()
self.vhost = nginx_obj.VirtualHost(
self.context.filename if self.context is not None else "",
self.addrs, self.ssl, True, self.server_names, self.dump_unspaced_list()[1],
self.get_path(), self)
self._update_vhost()
NGINX_PARSING_HOOKS = (ServerBlock, obj.Block, Include, obj.Sentence, obj.Directives)

View file

@ -95,18 +95,14 @@ class NginxParser(object):
def _build_addr_to_ssl(self):
"""Builds a map from address to whether it listens on ssl in any server block
"""
servers = self._get_raw_servers()
addr_to_ssl = {} # type: Dict[Tuple[str, str], bool]
for filename in servers:
for server, _ in servers[filename]:
# Parse the server block to save addr info
parsed_server = _parse_server_raw(server)
for addr in parsed_server['addrs']:
addr_tuple = addr.normalized_tuple()
if addr_tuple not in addr_to_ssl:
addr_to_ssl[addr_tuple] = addr.ssl
addr_to_ssl[addr_tuple] = addr.ssl or addr_to_ssl[addr_tuple]
servers = self.parsed_root.get_type(nginx_obj.ServerBlock)
addr_to_ssl = {}
for server in servers:
for addr in server.vhost.addrs:
addr_tuple = addr.normalized_tuple()
if addr_tuple not in addr_to_ssl:
addr_to_ssl[addr_tuple] = addr.ssl
addr_to_ssl[addr_tuple] = addr.ssl or addr_to_ssl[addr_tuple]
return addr_to_ssl
def _get_raw_servers(self):
@ -124,7 +120,7 @@ class NginxParser(object):
_do_for_subarray(tree, lambda x: len(x) >= 2 and x[0] == ['server'],
lambda x, y: srv.append((x[1], y)))
# Find 'include' statements in server blocks and append their trees
# Find 'include statement in server blocks and append their trees
for i, (server, path) in enumerate(servers[filename]):
new_server = self._get_included_directives(server)
servers[filename][i] = (new_server, path)
@ -295,8 +291,18 @@ class NginxParser(object):
of the server block instead of the bottom
"""
self._modify_server_directives(vhost,
functools.partial(_add_directives, directives, insert_at_top))
for directive in directives:
if not _is_whitespace_or_empty(directive):
vhost.raw_obj.add_directive(directive, insert_at_top)
self._sync_old_structs(vhost)
def _sync_old_structs(self, vhost):
vhost.raw_obj._update_vhost()
old = self.parsed[vhost.filep]
for index in vhost.path[:-1]:
old = old[index]
old[vhost.path[-1]] = vhost.raw_obj.dump_unspaced_list()
def update_or_add_server_directives(self, vhost, directives, insert_at_top=False):
"""Add or replace directives in the server block identified by vhost.
@ -317,8 +323,10 @@ class NginxParser(object):
of the server block instead of the bottom
"""
self._modify_server_directives(vhost,
functools.partial(_update_or_add_directives, directives, insert_at_top))
for directive in directives:
if not _is_whitespace_or_empty(directive):
vhost.raw_obj.update_or_add_directive(directive, insert_at_top)
self._sync_old_structs(vhost)
def remove_server_directives(self, vhost, directive_name, match_func=None):
"""Remove all directives of type directive_name.
@ -389,7 +397,6 @@ class NginxParser(object):
new_directives.append(directive)
new_directives.append("\n")
raw_in_parsed[1] = new_directives
self._update_vhost_based_on_new_directives(new_vhost, new_directives)
enclosing_block.append(raw_in_parsed)
@ -564,18 +571,11 @@ def _is_ssl_on_directive(entry):
len(entry) == 2 and entry[0] == 'ssl' and
entry[1] == 'on')
def _add_directives(directives, insert_at_top, block):
"""Adds directives to a config block."""
for directive in directives:
_add_directive(block, directive, insert_at_top)
if block and '\n' not in block[-1]: # could be " \n " or ["\n"] !
block.append(nginxparser.UnspacedList('\n'))
def _update_or_add_directives(directives, insert_at_top, block):
"""Adds or replaces directives in a config block."""
for directive in directives:
_update_or_add_directive(block, directive, insert_at_top)
if block and '\n' not in block[-1]: # could be " \n " or ["\n"] !
if block and '\n' not in block.spaced[-1]: # could be " \n " or ["\n"] !
block.append(nginxparser.UnspacedList('\n'))
@ -634,6 +634,7 @@ def _is_whitespace_or_comment(directive):
"""Is this directive either a whitespace or comment directive?"""
return len(directive) == 0 or directive[0] == '#'
# block = Statements
def _add_directive(block, directive, insert_at_top):
if not isinstance(directive, nginxparser.UnspacedList):
directive = nginxparser.UnspacedList(directive)
@ -733,6 +734,13 @@ def _apply_global_addr_ssl(addr_to_ssl, parsed_server):
if addr.ssl:
parsed_server['ssl'] = True
def _is_whitespace_or_empty(directive):
if not directive:
return True
if len(directive) != 1:
return False
return len(directive[0]) == 0 or directive[0].isspace()
def _parse_server_raw(server):
"""Parses a list of server directives.

View file

@ -168,6 +168,31 @@ class Directives(Parsable):
# ======== End overridden functions
def update_directive(self, statement, index):
""" upd8
"""
self._data[index] = statement
if index + 1 >= len(self._data) or not _is_certbot_comment(self._data[index+1]):
self._data.insert(index+1, _certbot_comment(self.context))
def find_directive(self, match_func):
for i, elem in enumerate(self._data):
if isinstance(elem, Sentence) and match_func(elem):
return i
return -1
def add_directive(self, statement, insert_at_top=False):
""" Takes in a parse obj
"""
index = 0
if insert_at_top:
self._data.insert(0, statement)
else:
index = len(self._data)
self._data.append(statement)
if not _is_comment(statement):
self._data.insert(index+1, _certbot_comment(self.context))
def get_type(self, match_type):
""" TODO
"""

View file

@ -177,6 +177,7 @@ class NginxConfiguratorTest(util.NginxTest):
if name == "ipv6.com":
self.assertTrue(vhost.ipv6_enabled())
# Make sure that we have SSL enabled also for IPv6 addr
print vhost
self.assertTrue(
any([True for x in vhost.addrs if x.ssl and x.ipv6]))
@ -494,23 +495,23 @@ class NginxConfiguratorTest(util.NginxTest):
self.assertEqual(
[[['server'], [
['server_name', '.example.com'],
['server_name', 'example.*'], [],
['server_name', 'example.*'],
['listen', '5001', 'ssl'], ['#', ' managed by Certbot'],
['ssl_certificate', 'example/fullchain.pem'], ['#', ' managed by Certbot'],
['ssl_certificate_key', 'example/key.pem'], ['#', ' managed by Certbot'],
['include', self.config.mod_ssl_conf], ['#', ' managed by Certbot'],
['ssl_dhparam', self.config.ssl_dhparams], ['#', ' managed by Certbot'],
[], []]],
]],
[['server'], [
[['if', '($host', '=', 'www.example.com)'], [
['return', '301', 'https://$host$request_uri']]],
['#', ' managed by Certbot'], [],
['#', ' managed by Certbot'],
['listen', '69.50.225.155:9000'],
['listen', '127.0.0.1'],
['server_name', '.example.com'],
['server_name', 'example.*'],
['return', '404'], ['#', ' managed by Certbot'], [], [], []]]],
generated_conf)
['return', '404'], ['#', ' managed by Certbot'], ]]][0],
generated_conf[0])
def test_split_for_headers(self):
example_conf = self.config.parser.abs_path('sites-enabled/example.com')
@ -525,22 +526,21 @@ class NginxConfiguratorTest(util.NginxTest):
self.assertEqual(
[[['server'], [
['server_name', '.example.com'],
['server_name', 'example.*'], [],
['server_name', 'example.*'],
['listen', '5001', 'ssl'], ['#', ' managed by Certbot'],
['ssl_certificate', 'example/fullchain.pem'], ['#', ' managed by Certbot'],
['ssl_certificate_key', 'example/key.pem'], ['#', ' managed by Certbot'],
['include', self.config.mod_ssl_conf], ['#', ' managed by Certbot'],
['ssl_dhparam', self.config.ssl_dhparams], ['#', ' managed by Certbot'],
[], [],
['add_header', 'Strict-Transport-Security', '"max-age=31536000"', 'always'],
['#', ' managed by Certbot'],
[], []]],
]],
[['server'], [
['listen', '69.50.225.155:9000'],
['listen', '127.0.0.1'],
['server_name', '.example.com'],
['server_name', 'example.*'],
[], [], []]]],
]]],
generated_conf)
def test_http_header_hsts(self):