From fe1ba9dad68909125326c05b3b2b3f8deef572e6 Mon Sep 17 00:00:00 2001 From: yan Date: Mon, 13 Apr 2015 22:57:06 -0700 Subject: [PATCH] Add test for nginx name matching --- letsencrypt/client/plugins/nginx/parser.py | 37 +++--- .../client/plugins/nginx/tests/parser_test.py | 125 +++++++++--------- 2 files changed, 86 insertions(+), 76 deletions(-) diff --git a/letsencrypt/client/plugins/nginx/parser.py b/letsencrypt/client/plugins/nginx/parser.py index b6a75344e..52b02e9e1 100644 --- a/letsencrypt/client/plugins/nginx/parser.py +++ b/letsencrypt/client/plugins/nginx/parser.py @@ -286,9 +286,9 @@ class NginxParser(object): changed = False if len(directive) == 0: continue - for line in block: + for index, line in enumerate(block): if len(line) > 0 and line[0] == directive[0]: - line = directive + block[index] = directive changed = True if not changed: raise errors.LetsEncryptMisconfigurationError( @@ -305,7 +305,7 @@ class NginxParser(object): split across multiple conf files. :param str filename: The absolute filename of the config file - :param str names: The server_name to match + :param set names: The server_name to match :param list directives: The directives to add :param bool replace: Whether to only replace existing directives @@ -329,14 +329,11 @@ def _do_for_subarray(entry, condition, func): :param function func: The function to call for each matching item """ - for item in entry: - if type(item) == list: - if condition(item): - try: - func(item) - except: - logging.warn("Error in _do_for_subarray for %s" % item) - else: + if type(entry) == list: + if condition(entry): + func(entry) + else: + for item in entry: _do_for_subarray(item, condition, func) @@ -387,10 +384,14 @@ def get_best_match(target_name, names): def _exact_match(target_name, name): - return (target_name == name or target_name == '.' + name) + return (target_name == name or '.' + target_name == name) def _wildcard_match(target_name, name, start): + # Degenerate case + if name == '*': + return True + parts = target_name.split('.') match_parts = name.split('.') @@ -399,8 +400,12 @@ def _wildcard_match(target_name, name, start): parts.reverse() match_parts.reverse() - # The first part must be a wildcard - if match_parts.pop(0) != '*': + if len(match_parts) == 0: + return False + + # The first part must be a wildcard or blank, e.g. '.eff.org' + first = match_parts.pop(0) + if first != '*' and first != '': return False target_name = '.'.join(parts) @@ -412,13 +417,13 @@ def _wildcard_match(target_name, name, start): def _regex_match(target_name, name): # Must start with a tilde - if name[0] != '~': + if len(name) < 2 or name[0] != '~': return False # After tilde is a perl-compatible regex try: regex = re.compile(name[1:]) - if regex.match(target_name): + if re.match(regex, target_name): return True else: return False diff --git a/letsencrypt/client/plugins/nginx/tests/parser_test.py b/letsencrypt/client/plugins/nginx/tests/parser_test.py index 28fa7057e..55c7f5405 100644 --- a/letsencrypt/client/plugins/nginx/tests/parser_test.py +++ b/letsencrypt/client/plugins/nginx/tests/parser_test.py @@ -1,6 +1,7 @@ """Tests for letsencrypt.client.plugins.nginx.parser.""" import glob import os +import re import shutil import sys import unittest @@ -8,9 +9,11 @@ import unittest import zope.component from letsencrypt.client.display import util as display_util +from letsencrypt.client.errors import LetsEncryptMisconfigurationError +from letsencrypt.client.plugins.nginx.nginxparser import dumps from letsencrypt.client.plugins.nginx.obj import Addr, VirtualHost -from letsencrypt.client.plugins.nginx.parser import NginxParser +from letsencrypt.client.plugins.nginx.parser import NginxParser, get_best_match from letsencrypt.client.plugins.nginx.tests import util @@ -115,68 +118,70 @@ class NginxParserTest(util.NginxTest): self.assertEquals(vhost2, somename) def test_add_server_directives(self): - pass + parser = NginxParser(self.config_path, self.ssl_options) + parser.add_server_directives(parser.abs_path('nginx.conf'), + set(['localhost']), + [['foo', 'bar'], ['ssl_certificate', + '/etc/ssl/cert.pem']]) + r = re.compile('foo bar;\n\s+ssl_certificate /etc/ssl/cert.pem') + self.assertEqual(1, len(re.findall(r, dumps(parser.parsed[ + parser.abs_path('nginx.conf')])))) + parser.add_server_directives(parser.abs_path('server.conf'), + set(['alias', 'another.alias', + 'somename']), + [['foo', 'bar'], ['ssl_certificate', + '/etc/ssl/cert2.pem']]) + self.assertEqual(parser.parsed[parser.abs_path('server.conf')], + [['server_name', 'somename alias another.alias'], + ['foo', 'bar'], + ['ssl_certificate', '/etc/ssl/cert2.pem']]) + + def test_replace_server_directives(self): + parser = NginxParser(self.config_path, self.ssl_options) + target = set(['.example.com', 'example.*']) + filep = parser.abs_path('sites-enabled/example.com') + parser.add_server_directives( + filep, target, [['server_name', 'foo bar']], True) + self.assertEqual( + parser.parsed[filep], + [[['server'], [['listen', '9000'], ['server_name', 'foo bar'], + ['server_name', 'foo bar']]]]) + self.assertRaises(LetsEncryptMisconfigurationError, + parser.add_server_directives, + filep, set(['foo', 'bar']), + [['ssl_certificate', 'cert.pem']], True) def test_get_best_match(self): - pass + target_name = 'www.eff.org' + names = [set(['www.eff.org', 'irrelevant.long.name.eff.org', '*.org']), + set(['eff.org', 'ww2.eff.org', 'test.www.eff.org']), + set(['*.eff.org', '.www.eff.org']), + set(['.eff.org', '*.org']), + set(['www.eff.', 'www.eff.*', '*.www.eff.org']), + set(['example.com', '~^(www\.)?(eff.+)', '*.eff.*']), + set(['*', '~^(www\.)?(eff.+)']), + set(['www.*', '~^(www\.)?(eff.+)', '.test.eff.org']), + set(['*.org', '*.eff.org', 'www.eff.*']), + set(['*.www.eff.org', 'www.*']), + set(['*.org']), + set([]), + set(['example.com'])] + winners = [('exact', 'www.eff.org'), + (None, None), + ('exact', '.www.eff.org'), + ('wildcard_start', '.eff.org'), + ('wildcard_end', 'www.eff.*'), + ('regex', '~^(www\.)?(eff.+)'), + ('wildcard_start', '*'), + ('wildcard_end', 'www.*'), + ('wildcard_start', '*.eff.org'), + ('wildcard_end', 'www.*'), + ('wildcard_start', '*.org'), + (None, None), + (None, None)] -# def test_find_dir(self): -# from letsencrypt.client.plugins.nginx.parser import case_i -# test = self.parser.find_dir(case_i("Listen"), "443") -# # This will only look in enabled hosts -# test2 = self.parser.find_dir(case_i("documentroot")) -# self.assertEqual(len(test), 2) -# self.assertEqual(len(test2), 3) -# -# def test_add_dir(self): -# aug_default = "/files" + self.parser.loc["default"] -# self.parser.add_dir(aug_default, "AddDirective", "test") -# -# self.assertTrue( -# self.parser.find_dir("AddDirective", "test", aug_default)) -# -# self.parser.add_dir(aug_default, "AddList", ["1", "2", "3", "4"]) -# matches = self.parser.find_dir("AddList", None, aug_default) -# for i, match in enumerate(matches): -# self.assertEqual(self.parser.aug.get(match), str(i + 1)) -# -# def test_add_dir_to_ifmodssl(self): -# """test add_dir_to_ifmodssl. -# -# Path must be valid before attempting to add to augeas -# -# """ -# from letsencrypt.client.plugins.nginx.parser import get_aug_path -# self.parser.add_dir_to_ifmodssl( -# get_aug_path(self.parser.loc["default"]), -# "FakeDirective", "123") -# -# matches = self.parser.find_dir("FakeDirective", "123") -# -# self.assertEqual(len(matches), 1) -# self.assertTrue("IfModule" in matches[0]) -# -# def test_get_aug_path(self): -# from letsencrypt.client.plugins.nginx.parser import get_aug_path -# self.assertEqual("/files/etc/nginx", get_aug_path("/etc/nginx")) -# -# def test_set_locations(self): -# with mock.patch("letsencrypt.client.plugins.nginx.parser." -# "os.path") as mock_path: -# -# mock_path.isfile.return_value = False -# -# # pylint: disable=protected-access -# self.assertRaises(errors.LetsEncryptConfiguratorError, -# self.parser._set_locations, self.ssl_options) -# -# mock_path.isfile.side_effect = [True, False, False] -# -# # pylint: disable=protected-access -# results = self.parser._set_locations(self.ssl_options) -# -# self.assertEqual(results["default"], results["listen"]) -# self.assertEqual(results["default"], results["name"]) + for i, winner in enumerate(winners): + self.assertEqual(winner, get_best_match(target_name, names[i])) if __name__ == "__main__":