Merge remote-tracking branch 'origin/master' into renew-symlink-safety

This commit is contained in:
Peter Eckersley 2016-09-26 14:11:30 -07:00
commit 1536d8ff65
9 changed files with 142 additions and 130 deletions

View file

@ -24,7 +24,6 @@ from certbot.plugins import common
from certbot_nginx import constants
from certbot_nginx import tls_sni_01
from certbot_nginx import obj
from certbot_nginx import parser
@ -154,7 +153,7 @@ class NginxConfigurator(common.Plugin):
['\n', 'ssl_certificate_key', ' ', key_path]]
try:
self.parser.add_server_directives(vhost.filep, vhost.names,
self.parser.add_server_directives(vhost,
cert_directives, replace=True)
logger.info("Deployed Certificate to VirtualHost %s for %s",
vhost.filep, vhost.names)
@ -198,12 +197,9 @@ class NginxConfigurator(common.Plugin):
matches = self._get_ranked_matches(target_name)
if not matches:
# No matches. Create a new vhost with this name in nginx.conf.
filep = self.parser.loc["root"]
new_block = [['server'], [['\n', 'server_name', ' ', target_name]]]
self.parser.add_http_directives(filep, new_block)
vhost = obj.VirtualHost(filep, set([]), False, True,
set([target_name]), list(new_block[1]))
# No matches. Raise a misconfiguration error.
raise errors.MisconfigurationError(
"Cannot find a VirtualHost matching domain %s." % (target_name))
elif matches[0]['rank'] in xrange(2, 6):
# Wildcard match - need to find the longest one
rank = matches[0]['rank']
@ -341,11 +337,7 @@ class NginxConfigurator(common.Plugin):
self.parser.loc["ssl_options"])
self.parser.add_server_directives(
vhost.filep, vhost.names, ssl_block, replace=False)
vhost.ssl = True
vhost.raw.extend(ssl_block)
vhost.addrs.add(obj.Addr(
'', str(self.config.tls_sni_01_port), True, False))
vhost, ssl_block, replace=False)
def get_all_certs_keys(self):
"""Find all existing keys, certs from configuration.
@ -406,7 +398,7 @@ class NginxConfigurator(common.Plugin):
'\n ']
], ['\n']]
self.parser.add_server_directives(
vhost.filep, vhost.names, redirect_block, replace=False)
vhost, redirect_block, replace=False)
logger.info("Redirecting all traffic to ssl in %s", vhost.filep)
def _enable_ocsp_stapling(self, vhost, chain_path):
@ -435,7 +427,7 @@ class NginxConfigurator(common.Plugin):
['\n ', 'ssl_stapling_verify', ' ', 'on'], ['\n']]
try:
self.parser.add_server_directives(vhost.filep, vhost.names,
self.parser.add_server_directives(vhost,
stapling_directives, replace=False)
except errors.MisconfigurationError as error:
logger.debug(error)

View file

@ -107,10 +107,12 @@ class VirtualHost(object): # pylint: disable=too-few-public-methods
:ivar bool ssl: SSLEngine on in vhost
:ivar bool enabled: Virtual host is enabled
:ivar list path: The indices into the parsed file used to access
the server block defining the vhost
"""
def __init__(self, filep, addrs, ssl, enabled, names, raw):
def __init__(self, filep, addrs, ssl, enabled, names, raw, path):
# pylint: disable=too-many-arguments
"""Initialize a VH."""
self.filep = filep
@ -119,6 +121,7 @@ class VirtualHost(object): # pylint: disable=too-few-public-methods
self.ssl = ssl
self.enabled = enabled
self.raw = raw
self.path = path
def __str__(self):
addr_str = ", ".join(str(addr) for addr in self.addrs)
@ -137,6 +140,8 @@ class VirtualHost(object): # pylint: disable=too-few-public-methods
return (self.filep == other.filep and
list(self.addrs) == list(other.addrs) and
self.names == other.names and
self.ssl == other.ssl and self.enabled == other.enabled)
self.ssl == other.ssl and
self.enabled == other.enabled and
self.path == other.path)
return False

View file

@ -104,15 +104,15 @@ class NginxParser(object):
# Find all the server blocks
_do_for_subarray(tree, lambda x: x[0] == ['server'],
lambda x: srv.append(x[1]))
lambda x, y: srv.append((x[1], y)))
# Find 'include' statements in server blocks and append their trees
for i, server in enumerate(servers[filename]):
for i, (server, path) in enumerate(servers[filename]):
new_server = self._get_included_directives(server)
servers[filename][i] = new_server
servers[filename][i] = (new_server, path)
for filename in servers:
for server in servers[filename]:
for server, path in servers[filename]:
# Parse the server block into a VirtualHost object
parsed_server = parse_server(server)
@ -121,7 +121,8 @@ class NginxParser(object):
parsed_server['ssl'],
enabled,
parsed_server['names'],
server)
server,
path)
vhosts.append(vhost)
return vhosts
@ -240,42 +241,10 @@ class NginxParser(object):
except IOError:
logger.error("Could not open file for writing: %s", filename)
def _has_server_names(self, entry, names):
"""Checks if a server block has the given set of server_names. This
is the primary way of identifying server blocks in the configurator.
Returns false if 'entry' doesn't look like a server block at all.
def add_server_directives(self, vhost, directives, replace):
"""Add or replace directives in the server block identified by vhost.
..todo :: Doesn't match server blocks whose server_name directives are
split across multiple conf files.
:param list entry: The block to search
:param set names: The names to match
:rtype: bool
"""
if len(names) == 0:
# Nothing to identify blocks with
return False
if not isinstance(entry, list):
# Can't be a server block
return False
new_entry = self._get_included_directives(entry)
server_names = set()
for item in new_entry:
if not isinstance(item, list):
# Can't be a server block
return False
if len(item) > 0 and item[0] == 'server_name':
server_names.update(_get_servernames(item[1]))
return server_names == names
def add_server_directives(self, filename, names, directives,
replace):
"""Add or replace directives in the first server block with names.
This method modifies vhost to be fully consistent with the new directives.
..note :: If replace is True, this raises a misconfiguration error
if the directive does not already exist.
@ -285,34 +254,32 @@ class NginxParser(object):
..todo :: Doesn't match server blocks whose server_name directives are
split across multiple conf files.
:param str filename: The absolute filename of the config file
:param set names: The server_name to match
:param :class:`~certbot_nginx.obj.VirtualHost` vhost: The vhost
whose information we use to match on
:param list directives: The directives to add
:param bool replace: Whether to only replace existing directives
"""
filename = vhost.filep
try:
_do_for_subarray(self.parsed[filename],
lambda x: self._has_server_names(x, names),
lambda x: _add_directives(x, directives, replace))
result = self.parsed[filename]
for index in vhost.path:
result = result[index]
if not isinstance(result, list) or len(result) != 2:
raise errors.MisconfigurationError("Not a server block.")
result = result[1]
_add_directives(result, directives, replace)
# update vhost based on new directives
new_server = self._get_included_directives(result)
parsed_server = parse_server(new_server)
vhost.addrs = parsed_server['addrs']
vhost.ssl = parsed_server['ssl']
vhost.names = parsed_server['names']
vhost.raw = new_server
except errors.MisconfigurationError as err:
raise errors.MisconfigurationError("Problem in %s: %s" % (filename, err.message))
def add_http_directives(self, filename, directives):
"""Adds directives to the first encountered HTTP block in filename.
We insert new directives at the top of the block to work around
https://trac.nginx.org/nginx/ticket/810: If the first server block
doesn't enable OCSP stapling, stapling is broken for all blocks.
:param str filename: The absolute filename of the config file
:param list directives: The directives to add
"""
_do_for_subarray(self.parsed[filename],
lambda x: x[0] == ['http'],
lambda x: x[1].insert(0, directives))
def get_all_certs_keys(self):
"""Gets all certs and keys in the nginx config.
@ -341,7 +308,7 @@ class NginxParser(object):
return c_k
def _do_for_subarray(entry, condition, func):
def _do_for_subarray(entry, condition, func, path=None):
"""Executes a function for a subarray of a nested array if it matches
the given condition.
@ -350,12 +317,14 @@ def _do_for_subarray(entry, condition, func):
:param function func: The function to call for each matching item
"""
if path is None:
path = []
if isinstance(entry, list):
if condition(entry):
func(entry)
func(entry, path)
else:
for item in entry:
_do_for_subarray(item, condition, func)
for index, item in enumerate(entry):
_do_for_subarray(item, condition, func, path + [index])
def get_best_match(target_name, names):

View file

@ -13,6 +13,7 @@ from acme import messages
from certbot import achallenges
from certbot import errors
from certbot_nginx import obj
from certbot_nginx import parser
from certbot_nginx.tests import util
@ -83,8 +84,12 @@ class NginxConfiguratorTest(util.NginxTest):
def test_save(self):
filep = self.config.parser.abs_path('sites-enabled/example.com')
mock_vhost = obj.VirtualHost(filep,
None, None, None,
set(['.example.com', 'example.*']),
None, [0])
self.config.parser.add_server_directives(
filep, set(['.example.com', 'example.*']),
mock_vhost,
[['listen', ' ', '5001 ssl']],
replace=False)
self.config.save()
@ -135,7 +140,8 @@ class NginxConfiguratorTest(util.NginxTest):
self.assertEqual(conf_path[name], path)
for name in bad_results:
self.assertEqual(set([name]), self.config.choose_vhost(name).names)
self.assertRaises(errors.MisconfigurationError,
self.config.choose_vhost, name)
def test_more_info(self):
self.assertTrue('nginx.conf' in self.config.more_info())

View file

@ -80,7 +80,7 @@ class VirtualHostTest(unittest.TestCase):
self.vhost1 = VirtualHost(
"filep",
set([Addr.fromstring("localhost")]), False, False,
set(['localhost']), [])
set(['localhost']), [], [])
def test_eq(self):
from certbot_nginx.obj import Addr
@ -88,7 +88,7 @@ class VirtualHostTest(unittest.TestCase):
vhost1b = VirtualHost(
"filep",
set([Addr.fromstring("localhost blah")]), False, False,
set(['localhost']), [])
set(['localhost']), [], [])
self.assertEqual(vhost1b, self.vhost1)
self.assertEqual(str(vhost1b), str(self.vhost1))

View file

@ -79,6 +79,30 @@ class NginxParserTest(util.NginxTest):
['server_name', 'example.*']]]],
parsed[0])
def test__do_for_subarray(self):
# pylint: disable=protected-access
mylists = [([[2], [3], [2]], [[0], [2]]),
([[2], [3], [4]], [[0]]),
([[4], [3], [2]], [[2]]),
([], []),
(2, []),
([[[2], [3], [2]], [[2], [3], [2]]],
[[0, 0], [0, 2], [1, 0], [1, 2]]),
([[[0], [3], [2]], [[2], [3], [2]]], [[0, 2], [1, 0], [1, 2]]),
([[[0], [3], [4]], [[2], [3], [2]]], [[1, 0], [1, 2]]),
([[[0], [3], [4]], [[5], [3], [2]]], [[1, 2]]),
([[[0], [3], [4]], [[5], [3], [0]]], [])]
for mylist, result in mylists:
paths = []
parser._do_for_subarray(mylist,
lambda x: isinstance(x, list) and
len(x) >= 1 and
x[0] == 2,
lambda x, y, pts=paths: pts.append(y))
self.assertEqual(paths, result)
def test_get_vhosts(self):
nparser = parser.NginxParser(self.config_path, self.ssl_options)
vhosts = nparser.get_vhosts()
@ -88,26 +112,28 @@ class NginxParserTest(util.NginxTest):
False, True,
set(['localhost',
r'~^(www\.)?(example|bar)\.']),
[])
[], [9, 1, 9])
vhost2 = obj.VirtualHost(nparser.abs_path('nginx.conf'),
[obj.Addr('somename', '8080', False, False),
obj.Addr('', '8000', False, False)],
False, True,
set(['somename', 'another.alias', 'alias']),
[])
[], [9, 1, 12])
vhost3 = obj.VirtualHost(nparser.abs_path('sites-enabled/example.com'),
[obj.Addr('69.50.225.155', '9000',
False, False),
obj.Addr('127.0.0.1', '', False, False)],
False, True,
set(['.example.com', 'example.*']), [])
set(['.example.com', 'example.*']), [], [0])
vhost4 = obj.VirtualHost(nparser.abs_path('sites-enabled/default'),
[obj.Addr('myhost', '', False, True)],
False, True, set(['www.example.org']), [])
False, True, set(['www.example.org']),
[], [0])
vhost5 = obj.VirtualHost(nparser.abs_path('foo.conf'),
[obj.Addr('*', '80', True, True)],
True, True, set(['*.www.foo.com',
'*.www.example.com']), [])
'*.www.example.com']),
[], [2, 1, 0])
self.assertEqual(5, len(vhosts))
example_com = [x for x in vhosts if 'example.com' in x.filep][0]
@ -123,9 +149,12 @@ class NginxParserTest(util.NginxTest):
def test_add_server_directives(self):
nparser = parser.NginxParser(self.config_path, self.ssl_options)
nparser.add_server_directives(nparser.abs_path('nginx.conf'),
set(['localhost',
mock_vhost = obj.VirtualHost(nparser.abs_path('nginx.conf'),
None, None, None,
set(['localhost',
r'~^(www\.)?(example|bar)\.']),
None, [9, 1, 9])
nparser.add_server_directives(mock_vhost,
[['foo', 'bar'], ['\n ', 'ssl_certificate', ' ',
'/etc/ssl/cert.pem']],
replace=False)
@ -133,47 +162,48 @@ class NginxParserTest(util.NginxTest):
dump = nginxparser.dumps(nparser.parsed[nparser.abs_path('nginx.conf')])
self.assertEqual(1, len(re.findall(ssl_re, dump)))
server_conf = nparser.abs_path('server.conf')
names = set(['alias', 'another.alias', 'somename'])
nparser.add_server_directives(server_conf, names,
example_com = nparser.abs_path('sites-enabled/example.com')
names = set(['.example.com', 'example.*'])
mock_vhost.filep = example_com
mock_vhost.names = names
mock_vhost.path = [0]
nparser.add_server_directives(mock_vhost,
[['foo', 'bar'], ['ssl_certificate',
'/etc/ssl/cert2.pem']],
replace=False)
nparser.add_server_directives(server_conf, names, [['foo', 'bar']],
nparser.add_server_directives(mock_vhost, [['foo', 'bar']],
replace=False)
from certbot_nginx.parser import COMMENT
self.assertEqual(nparser.parsed[server_conf],
[['server_name', 'somename alias another.alias'],
['foo', 'bar'],
['#', COMMENT],
['ssl_certificate', '/etc/ssl/cert2.pem'],
['#', COMMENT],
[], []
])
self.assertEqual(nparser.parsed[example_com],
[[['server'], [['listen', '69.50.225.155:9000'],
['listen', '127.0.0.1'],
['server_name', '.example.com'],
['server_name', 'example.*'],
['foo', 'bar'],
['#', COMMENT],
['ssl_certificate', '/etc/ssl/cert2.pem'],
['#', COMMENT], [], []
]]])
def test_add_http_directives(self):
nparser = parser.NginxParser(self.config_path, self.ssl_options)
filep = nparser.abs_path('nginx.conf')
block = [['server'],
[['listen', '80'],
['server_name', 'localhost']]]
nparser.add_http_directives(filep, block)
root = nparser.parsed[filep]
self.assertTrue(util.contains_at_depth(root, ['http'], 1))
self.assertTrue(util.contains_at_depth(root, block, 2))
# Check that our server block got inserted first among all server
# blocks.
http_block = [x for x in root if x[0] == ['http']][0][1]
server_blocks = [x for x in http_block if x[0] == ['server']]
self.assertEqual(server_blocks[0], block)
server_conf = nparser.abs_path('server.conf')
names = set(['alias', 'another.alias', 'somename'])
mock_vhost.filep = server_conf
mock_vhost.names = names
mock_vhost.path = []
self.assertRaises(errors.MisconfigurationError,
nparser.add_server_directives,
mock_vhost,
[['foo', 'bar'],
['ssl_certificate', '/etc/ssl/cert2.pem']],
replace=False)
def test_replace_server_directives(self):
nparser = parser.NginxParser(self.config_path, self.ssl_options)
target = set(['.example.com', 'example.*'])
filep = nparser.abs_path('sites-enabled/example.com')
mock_vhost = obj.VirtualHost(filep, None, None, None, target, None, [0])
nparser.add_server_directives(
filep, target, [['server_name', 'foobar.com']], replace=True)
mock_vhost, [['server_name', 'foobar.com']], replace=True)
from certbot_nginx.parser import COMMENT
self.assertEqual(
nparser.parsed[filep],
@ -182,9 +212,10 @@ class NginxParserTest(util.NginxTest):
['server_name', 'foobar.com'], ['#', COMMENT],
['server_name', 'example.*'], []
]]])
mock_vhost.names = set(['foobar.com', 'example.*'])
self.assertRaises(errors.MisconfigurationError,
nparser.add_server_directives,
filep, set(['foobar.com', 'example.*']),
mock_vhost,
[['ssl_certificate', 'cert.pem']],
replace=True)
@ -241,8 +272,11 @@ class NginxParserTest(util.NginxTest):
def test_get_all_certs_keys(self):
nparser = parser.NginxParser(self.config_path, self.ssl_options)
filep = nparser.abs_path('sites-enabled/example.com')
nparser.add_server_directives(filep,
set(['.example.com', 'example.*']),
mock_vhost = obj.VirtualHost(filep,
None, None, None,
set(['.example.com', 'example.*']),
None, [0])
nparser.add_server_directives(mock_vhost,
[['ssl_certificate', 'foo.pem'],
['ssl_certificate_key', 'bar.key'],
['listen', '443 ssl']],

View file

@ -31,7 +31,7 @@ class TlsSniPerformTest(util.NginxTest):
token="\xba\xa9\xda?<m\xaewmx\xea\xad\xadv\xf4\x02\xc9y"
"\x80\xe2_X\t\xe7\xc7\xa4\t\xca\xf7&\x945"
), "pending"),
domain="blah", account_key=account_key),
domain="another.alias", account_key=account_key),
achallenges.KeyAuthorizationAnnotatedChallenge(
challb=acme_util.chall_to_challb(
challenges.TLSSNI01(
@ -109,8 +109,8 @@ class TlsSniPerformTest(util.NginxTest):
http = self.sni.configurator.parser.parsed[
self.sni.configurator.parser.loc["root"]][-1]
self.assertTrue(['include', self.sni.challenge_conf] in http[1])
self.assertTrue(
util.contains_at_depth(http, ['server_name', 'blah'], 3))
self.assertFalse(
util.contains_at_depth(http, ['server_name', 'another.alias'], 3))
self.assertEqual(len(sni_responses), 3)
for i in xrange(3):

View file

@ -52,6 +52,7 @@ http {
listen 8081;
# IPv6.
listen [::]:8081 default ipv6only=on;
server_name nginx.wtf;
root $root/webroot;

View file

@ -15,9 +15,14 @@ logger = logging.getLogger(__name__)
# potentially occur from inside Python. Signals such as SIGILL were not
# included as they could be a sign of something devious and we should terminate
# immediately.
_SIGNALS = ([signal.SIGTERM] if os.name == "nt" else
[signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT,
signal.SIGXCPU, signal.SIGXFSZ])
_SIGNALS = [signal.SIGTERM]
if os.name != "nt":
for signal_code in [signal.SIGHUP, signal.SIGQUIT,
signal.SIGXCPU, signal.SIGXFSZ]:
# Adding only those signals that their default action is not Ignore.
# This is platform-dependent, so we check it dynamically.
if signal.getsignal(signal_code) != signal.SIG_IGN:
_SIGNALS.append(signal_code)
class ErrorHandler(object):