mirror of
https://github.com/certbot/certbot.git
synced 2026-06-14 19:20:09 -04:00
Everything working with add_directive except save
This commit is contained in:
parent
7a79c55af8
commit
fcc76618fa
4 changed files with 103 additions and 41 deletions
|
|
@ -3,6 +3,7 @@ import glob
|
|||
import logging
|
||||
import pyparsing
|
||||
|
||||
from certbot import errors
|
||||
from certbot.compat import os
|
||||
|
||||
from certbot_nginx import nginxparser
|
||||
|
|
@ -91,6 +92,9 @@ class ServerBlock(obj.Block):
|
|||
def __init__(self, context=None):
|
||||
super(ServerBlock, self).__init__(context)
|
||||
self.vhost = None
|
||||
self.addrs = set()
|
||||
self.ssl = False
|
||||
self.server_names = set()
|
||||
|
||||
@staticmethod
|
||||
def should_parse(lists):
|
||||
|
|
@ -114,16 +118,37 @@ class ServerBlock(obj.Block):
|
|||
if directive[0] == 'server_name':
|
||||
self.server_names.update(x.strip('"\'') for x in directive[1:])
|
||||
for ssl in self.get_directives('ssl'):
|
||||
if ssl.words[1] == "on":
|
||||
if ssl[1] == "on":
|
||||
self.ssl = True
|
||||
apply_ssl_to_all_addrs = True
|
||||
if apply_ssl_to_all_addrs:
|
||||
for addr in self.addrs:
|
||||
addr.ssl = True
|
||||
return nginx_obj.VirtualHost(
|
||||
self.context.filename if self.context is not None else "",
|
||||
self.addrs, self.ssl, True, self.server_names, self.dump_unspaced_list()[1],
|
||||
self.get_path(), self)
|
||||
self.vhost.addrs = self.addrs
|
||||
self.vhost.names = self.server_names
|
||||
self.vhost.ssl = self.ssl
|
||||
self.vhost.raw = self.dump_unspaced_list()[1]
|
||||
self.vhost.raw_obj = self
|
||||
|
||||
def add_directive(self, raw_list, insert_at_top=False):
|
||||
""" Adds a single directive to this Server Block's contents, while enforcing
|
||||
repeatability rules."""
|
||||
statement = obj.parse_raw(raw_list, self.contents.child_context(), add_spaces=False)
|
||||
if isinstance(statement, obj.Sentence) and statement[0] not in self.REPEATABLE_DIRECTIVES \
|
||||
and len(list(self.get_directives(statement[0]))) > 0:
|
||||
raise errors.MisconfigurationError(
|
||||
"Existing %s directive conflicts with %s" % (statement[0], statement))
|
||||
self.contents.add_directive(statement, insert_at_top)
|
||||
|
||||
def update_or_add_directive(self, raw_list, insert_at_top=False):
|
||||
""" Adds a single directive to this Server Block's contents, while enforcing
|
||||
repeatability rules."""
|
||||
statement = obj.parse_raw(raw_list, self.contents.child_context(), add_spaces=False)
|
||||
index = self.contents.find_directive(lambda elem: elem[0] == statement[0])
|
||||
if index < 0:
|
||||
self.contents.add_directive(statement, insert_at_top)
|
||||
return
|
||||
self.contents.update_directive(statement, index)
|
||||
|
||||
def get_directives(self, name, match=None):
|
||||
""" Retrieves any child directive starting with `name`.
|
||||
|
|
@ -138,6 +163,10 @@ class ServerBlock(obj.Block):
|
|||
""" Parses lists into a ServerBlock object, and creates a
|
||||
corresponding VirtualHost metadata object. """
|
||||
super(ServerBlock, self).parse(raw_list, add_spaces)
|
||||
self.vhost = self._update_vhost()
|
||||
self.vhost = nginx_obj.VirtualHost(
|
||||
self.context.filename if self.context is not None else "",
|
||||
self.addrs, self.ssl, True, self.server_names, self.dump_unspaced_list()[1],
|
||||
self.get_path(), self)
|
||||
self._update_vhost()
|
||||
|
||||
NGINX_PARSING_HOOKS = (ServerBlock, obj.Block, Include, obj.Sentence, obj.Directives)
|
||||
|
|
|
|||
|
|
@ -95,18 +95,14 @@ class NginxParser(object):
|
|||
def _build_addr_to_ssl(self):
|
||||
"""Builds a map from address to whether it listens on ssl in any server block
|
||||
"""
|
||||
servers = self._get_raw_servers()
|
||||
|
||||
addr_to_ssl = {} # type: Dict[Tuple[str, str], bool]
|
||||
for filename in servers:
|
||||
for server, _ in servers[filename]:
|
||||
# Parse the server block to save addr info
|
||||
parsed_server = _parse_server_raw(server)
|
||||
for addr in parsed_server['addrs']:
|
||||
addr_tuple = addr.normalized_tuple()
|
||||
if addr_tuple not in addr_to_ssl:
|
||||
addr_to_ssl[addr_tuple] = addr.ssl
|
||||
addr_to_ssl[addr_tuple] = addr.ssl or addr_to_ssl[addr_tuple]
|
||||
servers = self.parsed_root.get_type(nginx_obj.ServerBlock)
|
||||
addr_to_ssl = {}
|
||||
for server in servers:
|
||||
for addr in server.vhost.addrs:
|
||||
addr_tuple = addr.normalized_tuple()
|
||||
if addr_tuple not in addr_to_ssl:
|
||||
addr_to_ssl[addr_tuple] = addr.ssl
|
||||
addr_to_ssl[addr_tuple] = addr.ssl or addr_to_ssl[addr_tuple]
|
||||
return addr_to_ssl
|
||||
|
||||
def _get_raw_servers(self):
|
||||
|
|
@ -124,7 +120,7 @@ class NginxParser(object):
|
|||
_do_for_subarray(tree, lambda x: len(x) >= 2 and x[0] == ['server'],
|
||||
lambda x, y: srv.append((x[1], y)))
|
||||
|
||||
# Find 'include' statements in server blocks and append their trees
|
||||
# Find 'include statement in server blocks and append their trees
|
||||
for i, (server, path) in enumerate(servers[filename]):
|
||||
new_server = self._get_included_directives(server)
|
||||
servers[filename][i] = (new_server, path)
|
||||
|
|
@ -295,8 +291,18 @@ class NginxParser(object):
|
|||
of the server block instead of the bottom
|
||||
|
||||
"""
|
||||
self._modify_server_directives(vhost,
|
||||
functools.partial(_add_directives, directives, insert_at_top))
|
||||
|
||||
for directive in directives:
|
||||
if not _is_whitespace_or_empty(directive):
|
||||
vhost.raw_obj.add_directive(directive, insert_at_top)
|
||||
self._sync_old_structs(vhost)
|
||||
|
||||
def _sync_old_structs(self, vhost):
|
||||
vhost.raw_obj._update_vhost()
|
||||
old = self.parsed[vhost.filep]
|
||||
for index in vhost.path[:-1]:
|
||||
old = old[index]
|
||||
old[vhost.path[-1]] = vhost.raw_obj.dump_unspaced_list()
|
||||
|
||||
def update_or_add_server_directives(self, vhost, directives, insert_at_top=False):
|
||||
"""Add or replace directives in the server block identified by vhost.
|
||||
|
|
@ -317,8 +323,10 @@ class NginxParser(object):
|
|||
of the server block instead of the bottom
|
||||
|
||||
"""
|
||||
self._modify_server_directives(vhost,
|
||||
functools.partial(_update_or_add_directives, directives, insert_at_top))
|
||||
for directive in directives:
|
||||
if not _is_whitespace_or_empty(directive):
|
||||
vhost.raw_obj.update_or_add_directive(directive, insert_at_top)
|
||||
self._sync_old_structs(vhost)
|
||||
|
||||
def remove_server_directives(self, vhost, directive_name, match_func=None):
|
||||
"""Remove all directives of type directive_name.
|
||||
|
|
@ -389,7 +397,6 @@ class NginxParser(object):
|
|||
new_directives.append(directive)
|
||||
new_directives.append("\n")
|
||||
raw_in_parsed[1] = new_directives
|
||||
|
||||
self._update_vhost_based_on_new_directives(new_vhost, new_directives)
|
||||
|
||||
enclosing_block.append(raw_in_parsed)
|
||||
|
|
@ -564,18 +571,11 @@ def _is_ssl_on_directive(entry):
|
|||
len(entry) == 2 and entry[0] == 'ssl' and
|
||||
entry[1] == 'on')
|
||||
|
||||
def _add_directives(directives, insert_at_top, block):
|
||||
"""Adds directives to a config block."""
|
||||
for directive in directives:
|
||||
_add_directive(block, directive, insert_at_top)
|
||||
if block and '\n' not in block[-1]: # could be " \n " or ["\n"] !
|
||||
block.append(nginxparser.UnspacedList('\n'))
|
||||
|
||||
def _update_or_add_directives(directives, insert_at_top, block):
|
||||
"""Adds or replaces directives in a config block."""
|
||||
for directive in directives:
|
||||
_update_or_add_directive(block, directive, insert_at_top)
|
||||
if block and '\n' not in block[-1]: # could be " \n " or ["\n"] !
|
||||
if block and '\n' not in block.spaced[-1]: # could be " \n " or ["\n"] !
|
||||
block.append(nginxparser.UnspacedList('\n'))
|
||||
|
||||
|
||||
|
|
@ -634,6 +634,7 @@ def _is_whitespace_or_comment(directive):
|
|||
"""Is this directive either a whitespace or comment directive?"""
|
||||
return len(directive) == 0 or directive[0] == '#'
|
||||
|
||||
# block = Statements
|
||||
def _add_directive(block, directive, insert_at_top):
|
||||
if not isinstance(directive, nginxparser.UnspacedList):
|
||||
directive = nginxparser.UnspacedList(directive)
|
||||
|
|
@ -733,6 +734,13 @@ def _apply_global_addr_ssl(addr_to_ssl, parsed_server):
|
|||
if addr.ssl:
|
||||
parsed_server['ssl'] = True
|
||||
|
||||
def _is_whitespace_or_empty(directive):
|
||||
if not directive:
|
||||
return True
|
||||
if len(directive) != 1:
|
||||
return False
|
||||
return len(directive[0]) == 0 or directive[0].isspace()
|
||||
|
||||
def _parse_server_raw(server):
|
||||
"""Parses a list of server directives.
|
||||
|
||||
|
|
|
|||
|
|
@ -168,6 +168,31 @@ class Directives(Parsable):
|
|||
|
||||
# ======== End overridden functions
|
||||
|
||||
def update_directive(self, statement, index):
|
||||
""" upd8
|
||||
"""
|
||||
self._data[index] = statement
|
||||
if index + 1 >= len(self._data) or not _is_certbot_comment(self._data[index+1]):
|
||||
self._data.insert(index+1, _certbot_comment(self.context))
|
||||
|
||||
def find_directive(self, match_func):
|
||||
for i, elem in enumerate(self._data):
|
||||
if isinstance(elem, Sentence) and match_func(elem):
|
||||
return i
|
||||
return -1
|
||||
|
||||
def add_directive(self, statement, insert_at_top=False):
|
||||
""" Takes in a parse obj
|
||||
"""
|
||||
index = 0
|
||||
if insert_at_top:
|
||||
self._data.insert(0, statement)
|
||||
else:
|
||||
index = len(self._data)
|
||||
self._data.append(statement)
|
||||
if not _is_comment(statement):
|
||||
self._data.insert(index+1, _certbot_comment(self.context))
|
||||
|
||||
def get_type(self, match_type):
|
||||
""" TODO
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -177,6 +177,7 @@ class NginxConfiguratorTest(util.NginxTest):
|
|||
if name == "ipv6.com":
|
||||
self.assertTrue(vhost.ipv6_enabled())
|
||||
# Make sure that we have SSL enabled also for IPv6 addr
|
||||
print vhost
|
||||
self.assertTrue(
|
||||
any([True for x in vhost.addrs if x.ssl and x.ipv6]))
|
||||
|
||||
|
|
@ -494,23 +495,23 @@ class NginxConfiguratorTest(util.NginxTest):
|
|||
self.assertEqual(
|
||||
[[['server'], [
|
||||
['server_name', '.example.com'],
|
||||
['server_name', 'example.*'], [],
|
||||
['server_name', 'example.*'],
|
||||
['listen', '5001', 'ssl'], ['#', ' managed by Certbot'],
|
||||
['ssl_certificate', 'example/fullchain.pem'], ['#', ' managed by Certbot'],
|
||||
['ssl_certificate_key', 'example/key.pem'], ['#', ' managed by Certbot'],
|
||||
['include', self.config.mod_ssl_conf], ['#', ' managed by Certbot'],
|
||||
['ssl_dhparam', self.config.ssl_dhparams], ['#', ' managed by Certbot'],
|
||||
[], []]],
|
||||
]],
|
||||
[['server'], [
|
||||
[['if', '($host', '=', 'www.example.com)'], [
|
||||
['return', '301', 'https://$host$request_uri']]],
|
||||
['#', ' managed by Certbot'], [],
|
||||
['#', ' managed by Certbot'],
|
||||
['listen', '69.50.225.155:9000'],
|
||||
['listen', '127.0.0.1'],
|
||||
['server_name', '.example.com'],
|
||||
['server_name', 'example.*'],
|
||||
['return', '404'], ['#', ' managed by Certbot'], [], [], []]]],
|
||||
generated_conf)
|
||||
['return', '404'], ['#', ' managed by Certbot'], ]]][0],
|
||||
generated_conf[0])
|
||||
|
||||
def test_split_for_headers(self):
|
||||
example_conf = self.config.parser.abs_path('sites-enabled/example.com')
|
||||
|
|
@ -525,22 +526,21 @@ class NginxConfiguratorTest(util.NginxTest):
|
|||
self.assertEqual(
|
||||
[[['server'], [
|
||||
['server_name', '.example.com'],
|
||||
['server_name', 'example.*'], [],
|
||||
['server_name', 'example.*'],
|
||||
['listen', '5001', 'ssl'], ['#', ' managed by Certbot'],
|
||||
['ssl_certificate', 'example/fullchain.pem'], ['#', ' managed by Certbot'],
|
||||
['ssl_certificate_key', 'example/key.pem'], ['#', ' managed by Certbot'],
|
||||
['include', self.config.mod_ssl_conf], ['#', ' managed by Certbot'],
|
||||
['ssl_dhparam', self.config.ssl_dhparams], ['#', ' managed by Certbot'],
|
||||
[], [],
|
||||
['add_header', 'Strict-Transport-Security', '"max-age=31536000"', 'always'],
|
||||
['#', ' managed by Certbot'],
|
||||
[], []]],
|
||||
]],
|
||||
[['server'], [
|
||||
['listen', '69.50.225.155:9000'],
|
||||
['listen', '127.0.0.1'],
|
||||
['server_name', '.example.com'],
|
||||
['server_name', 'example.*'],
|
||||
[], [], []]]],
|
||||
]]],
|
||||
generated_conf)
|
||||
|
||||
def test_http_header_hsts(self):
|
||||
|
|
|
|||
Loading…
Reference in a new issue