Fix TLS_SNI & associated tests

This commit is contained in:
Peter Eckersley 2016-06-18 14:52:07 -07:00
parent 7bcc23d9f5
commit e4f88506cc
8 changed files with 43 additions and 29 deletions

View file

@ -165,6 +165,7 @@ class UnspacedList(list):
"""Wrap a list [of lists], making any whitespace entries magically invisible"""
def __init__(self, list_source):
# ensure our argument is not a generator, and duplicate any sublists
self.spaced = copy.deepcopy(list(list_source))
# Turn self into a version of the source list that has spaces removed
@ -173,16 +174,25 @@ class UnspacedList(list):
for i, entry in reversed(list(enumerate(self))):
if isinstance(entry, list):
sublist = UnspacedList(entry)
list.__setitem__(self, i, sublist)
if sublist != [] or sublist.spaced == []:
list.__setitem__(self, i, sublist)
else:
# if a sublist is exclusively spacey entries, it might
# choke the high level parser, so make it disappear
list.__delitem__(self, i)
self.spaced[i] = sublist.spaced
elif spacey(entry):
list.__delitem__(self, i)
# don't delete comments
if "#" not in self[:i]:
list.__delitem__(self, i)
def insert(self, i, x):
if hasattr(x, "spaced"):
self.spaced.insert(i + self._spaces_before(i), x.spaced)
else:
if not isinstance(x, list): # str or None
self.spaced.insert(i + self._spaces_before(i), x)
else:
if not hasattr(x, "spaced"):
x = UnspacedList(x)
self.spaced.insert(i + self._spaces_before(i), x.spaced)
list.insert(self, i, x)
def append(self, x):

View file

@ -130,7 +130,7 @@ class VirtualHost(object): # pylint: disable=too-few-public-methods
self.names, self.ssl, self.enabled))
def __repr__(self):
return "VirtualHost(" + self.__str__().replace("\n",",") + ")\n"
return "VirtualHost(" + self.__str__().replace("\n",", ") + ")\n"
def __eq__(self, other):
if isinstance(other, self.__class__):

View file

@ -1,4 +1,5 @@
"""NginxParser is a member object of the NginxConfigurator class."""
import copy
import glob
import logging
import os
@ -113,6 +114,7 @@ class NginxParser(object):
for filename in servers:
for server in servers[filename]:
# Parse the server block into a VirtualHost object
parsed_server = parse_server(server)
vhost = obj.VirtualHost(filename,
parsed_server['addrs'],
@ -132,7 +134,7 @@ class NginxParser(object):
:rtype: list
"""
result = list(block) # Copy the list to keep self.parsed idempotent
result = copy.deepcopy(block) # Copy the list to keep self.parsed idempotent
for directive in block:
if _is_include_directive(directive):
included_files = glob.glob(
@ -465,6 +467,8 @@ def parse_server(server):
'names': set()}
for directive in server:
if not directive:
continue
if directive[0] == 'listen':
addr = obj.Addr.fromstring(directive[1])
parsed_server['addrs'].add(addr)
@ -506,7 +510,6 @@ def _add_directive(block, directive, replace):
"""
directive = nginxparser.UnspacedList(directive)
print "Unspacified", directive.spaced, directive
if len(directive) == 0:
# whitespace
block.append(directive)

View file

@ -83,7 +83,7 @@ class NginxConfiguratorTest(util.NginxTest):
filep = self.config.parser.abs_path('sites-enabled/example.com')
self.config.parser.add_server_directives(
filep, set(['.example.com', 'example.*']),
[['listen', '5001 ssl']],
[['listen', ' ', '5001 ssl']],
replace=False)
self.config.save()

View file

@ -164,9 +164,9 @@ class TestRawNginxParser(unittest.TestCase):
['#', ' Kilroy was here'],
['check_status'],
[['server'],
[['#'],
[['#', ''],
['#', " Don't forget to open up your firewall!"],
['#'],
['#', ''],
['listen', '1234'],
['#', ' listen 80;']]],
])

View file

@ -133,21 +133,16 @@ class TlsSniPerformTest(util.NginxTest):
http = self.sni.configurator.parser.parsed[
self.sni.configurator.parser.loc["root"]][-1]
print "http", http
#print "SPACED\n", http.spaced
self.assertTrue(['include', self.sni.challenge_conf] in http[1])
vhosts = self.sni.configurator.parser.get_vhosts()
print "Got", vhosts
vhs = [vh for vh in vhosts if vh.filep == self.sni.challenge_conf]
print "And now", vhs
for vhost in vhs:
if vhost.addrs == set(v_addr1):
response = self.achalls[0].response(self.account_key)
else:
response = self.achalls[2].response(self.account_key)
print vhost.addrs, set(v_addr2)
self.assertEqual(vhost.addrs, set(v_addr2))
self.assertEqual(vhost.names, set([response.z_domain]))

View file

@ -1,4 +1,5 @@
"""Common utilities for certbot_nginx."""
import copy
import os
import pkg_resources
import unittest
@ -16,6 +17,7 @@ from certbot.plugins import common
from certbot_nginx import constants
from certbot_nginx import configurator
from certbot_nginx import nginxparser
class NginxTest(unittest.TestCase): # pylint: disable=too-few-public-methods
@ -82,12 +84,15 @@ def filter_comments(tree):
def traverse(tree):
"""Generator dropping comment nodes"""
for key, values in tree:
for entry in tree:
key, values = entry
if isinstance(key, list):
yield [key, filter_comments(values)]
new = copy.deepcopy(entry)
new[1] = filter_comments(values)
yield new
else:
if key != '#':
yield [key, values]
yield entry
return list(traverse(tree))

View file

@ -93,10 +93,10 @@ class NginxTlsSni01(common.TLSSNI01):
# Add the 'include' statement for the challenges if it doesn't exist
# already in the main config
included = False
include_directive = ['include', self.challenge_conf]
include_directive = ['include', ' ', self.challenge_conf]
root = self.configurator.parser.loc["root"]
bucket_directive = ['server_names_hash_bucket_size', '128']
bucket_directive = ['server_names_hash_bucket_size', ' ', '128']
main = self.configurator.parser.parsed[root]
for key, body in main:
@ -118,6 +118,7 @@ class NginxTlsSni01(common.TLSSNI01):
config = [self._make_server_block(pair[0], pair[1])
for pair in itertools.izip(self.achalls, ll_addrs)]
config = nginxparser.UnspacedList(config)
self.configurator.reverter.register_file_creation(
True, self.challenge_conf)
@ -142,19 +143,19 @@ class NginxTlsSni01(common.TLSSNI01):
document_root = os.path.join(
self.configurator.config.work_dir, "tls_sni_01_page")
block = [['listen', str(addr)] for addr in addrs]
block = [['listen', ' ', str(addr)] for addr in addrs]
block.extend([['server_name',
block.extend([['server_name', ' ',
achall.response(achall.account_key).z_domain],
['include', self.configurator.parser.loc["ssl_options"]],
['include', ' ', self.configurator.parser.loc["ssl_options"]],
# access and error logs necessary for
# integration testing (non-root)
['access_log', os.path.join(
['access_log', ' ', os.path.join(
self.configurator.config.work_dir, 'access.log')],
['error_log', os.path.join(
['error_log', ' ', os.path.join(
self.configurator.config.work_dir, 'error.log')],
['ssl_certificate', self.get_cert_path(achall)],
['ssl_certificate_key', self.get_key_path(achall)],
[['location', '/'], [['root', document_root]]]])
['ssl_certificate', ' ', self.get_cert_path(achall)],
['ssl_certificate_key', ' ', self.get_key_path(achall)],
[['location', ' ', '/'], [['root', ' ', document_root]]]])
return [['server'], block]