Mark Nginx vhosts as ssl when any vhost is on ssl at that address (#3856)

* Move parse_server to be a method of NginxParser

* add super equal method to more correctly check addr equality in nginx should we support ipv6 in nginx in the future

* add addr:normalized_tuple method

* mark addresses listening sslishly due to another server block listening sslishly on that address

* test turning on ssl globally

* add docstring

* lint and remove extra file
This commit is contained in:
Erica Portnoy 2016-12-05 19:17:04 -08:00 committed by Peter Eckersley
parent 3dbf5c9fcb
commit f0a7bb0e33
6 changed files with 158 additions and 70 deletions

View file

@ -93,9 +93,16 @@ class Addr(common.Addr):
def __repr__(self):
return "Addr(" + self.__str__() + ")"
def super_eq(self, other):
"""Check ip/port equality, with IPv6 support.
"""
# Nginx plugin currently doesn't support IPv6 but this will
# future-proof it
return super(Addr, self).__eq__(other)
def __eq__(self, other):
if isinstance(other, self.__class__):
return (self.tup == other.tup and
return (self.super_eq(other) and
self.ssl == other.ssl and
self.default == other.default)
return False

View file

@ -82,21 +82,28 @@ class NginxParser(object):
else:
return path
def get_vhosts(self):
# pylint: disable=cell-var-from-loop
"""Gets list of all 'virtual hosts' found in Nginx configuration.
Technically this is a misnomer because Nginx does not have virtual
hosts, it has 'server blocks'.
:returns: List of :class:`~certbot_nginx.obj.VirtualHost`
objects found in configuration
:rtype: list
def _build_addr_to_ssl(self):
"""Builds a map from address to whether it listens on ssl in any server block
"""
enabled = True # We only look at enabled vhosts for now
vhosts = []
servers = {}
servers = self._get_raw_servers()
addr_to_ssl = {}
for filename in servers:
for server, _ in servers[filename]:
# Parse the server block to save addr info
parsed_server = _parse_server_raw(server)
for addr in parsed_server['addrs']:
addr_tuple = addr.normalized_tuple()
if addr_tuple not in addr_to_ssl:
addr_to_ssl[addr_tuple] = addr.ssl
addr_to_ssl[addr_tuple] = addr.ssl or addr_to_ssl[addr_tuple]
return addr_to_ssl
def _get_raw_servers(self):
# pylint: disable=cell-var-from-loop
"""Get a map of unparsed all server blocks
"""
servers = {}
for filename in self.parsed:
tree = self.parsed[filename]
servers[filename] = []
@ -110,12 +117,28 @@ class NginxParser(object):
for i, (server, path) in enumerate(servers[filename]):
new_server = self._get_included_directives(server)
servers[filename][i] = (new_server, path)
return servers
def get_vhosts(self):
# pylint: disable=cell-var-from-loop
"""Gets list of all 'virtual hosts' found in Nginx configuration.
Technically this is a misnomer because Nginx does not have virtual
hosts, it has 'server blocks'.
:returns: List of :class:`~certbot_nginx.obj.VirtualHost`
objects found in configuration
:rtype: list
"""
enabled = True # We only look at enabled vhosts for now
servers = self._get_raw_servers()
vhosts = []
for filename in servers:
for server, path in servers[filename]:
# Parse the server block into a VirtualHost object
parsed_server = parse_server(server)
parsed_server = _parse_server_raw(server)
vhost = obj.VirtualHost(filename,
parsed_server['addrs'],
parsed_server['ssl'],
@ -125,8 +148,20 @@ class NginxParser(object):
path)
vhosts.append(vhost)
self._update_vhosts_addrs_ssl(vhosts)
return vhosts
def _update_vhosts_addrs_ssl(self, vhosts):
"""Update a list of raw parsed vhosts to include global address sslishness
"""
addr_to_ssl = self._build_addr_to_ssl()
for vhost in vhosts:
for addr in vhost.addrs:
addr.ssl = addr_to_ssl[addr.normalized_tuple()]
if addr.ssl:
vhost.ssl = True
def _get_included_directives(self, block):
"""Returns array with the "include" directives expanded out by
concatenating the contents of the included file to the block.
@ -241,6 +276,17 @@ class NginxParser(object):
except IOError:
logger.error("Could not open file for writing: %s", filename)
def parse_server(self, server):
"""Parses a list of server directives, accounting for global address sslishness.
:param list server: list of directives in a server block
:rtype: dict
"""
addr_to_ssl = self._build_addr_to_ssl()
parsed_server = _parse_server_raw(server)
_apply_global_addr_ssl(addr_to_ssl, parsed_server)
return parsed_server
def has_ssl_on_directive(self, vhost):
"""Does vhost have ssl on for all ports?
@ -290,7 +336,7 @@ class NginxParser(object):
# update vhost based on new directives
new_server = self._get_included_directives(result)
parsed_server = parse_server(new_server)
parsed_server = self.parse_server(new_server)
vhost.addrs = parsed_server['addrs']
vhost.ssl = parsed_server['ssl']
vhost.names = parsed_server['names']
@ -434,41 +480,6 @@ def _get_servernames(names):
names = re.sub(whitespace_re, ' ', names)
return names.split(' ')
def parse_server(server):
"""Parses a list of server directives.
:param list server: list of directives in a server block
:rtype: dict
"""
parsed_server = {'addrs': set(),
'ssl': False,
'names': set()}
apply_ssl_to_all_addrs = False
for directive in server:
if not directive:
continue
if directive[0] == 'listen':
addr = obj.Addr.fromstring(directive[1])
parsed_server['addrs'].add(addr)
if not parsed_server['ssl'] and addr.ssl:
parsed_server['ssl'] = True
elif directive[0] == 'server_name':
parsed_server['names'].update(
_get_servernames(directive[1]))
elif directive[0] == 'ssl' and directive[1] == 'on':
parsed_server['ssl'] = True
apply_ssl_to_all_addrs = True
if apply_ssl_to_all_addrs:
for addr in parsed_server['addrs']:
addr.ssl = True
return parsed_server
def _add_directives(block, directives, replace):
"""Adds or replaces directives in a config block.
@ -549,3 +560,44 @@ def _add_directive(block, directive, replace):
'tried to insert directive "{0}" but found '
'conflicting "{1}".'.format(directive, block[location]))
def _apply_global_addr_ssl(addr_to_ssl, parsed_server):
"""Apply global sslishness information to the parsed server block
"""
for addr in parsed_server['addrs']:
addr.ssl = addr_to_ssl[addr.normalized_tuple()]
if addr.ssl:
parsed_server['ssl'] = True
def _parse_server_raw(server):
"""Parses a list of server directives.
:param list server: list of directives in a server block
:rtype: dict
"""
parsed_server = {'addrs': set(),
'ssl': False,
'names': set()}
apply_ssl_to_all_addrs = False
for directive in server:
if not directive:
continue
if directive[0] == 'listen':
addr = obj.Addr.fromstring(directive[1])
parsed_server['addrs'].add(addr)
if addr.ssl:
parsed_server['ssl'] = True
elif directive[0] == 'server_name':
parsed_server['names'].update(
_get_servernames(directive[1]))
elif directive[0] == 'ssl' and directive[1] == 'on':
parsed_server['ssl'] = True
apply_ssl_to_all_addrs = True
if apply_ssl_to_all_addrs:
for addr in parsed_server['addrs']:
addr.ssl = True
return parsed_server

View file

@ -40,7 +40,7 @@ class NginxConfiguratorTest(util.NginxTest):
def test_prepare(self):
self.assertEqual((1, 6, 2), self.config.version)
self.assertEqual(7, len(self.config.parser.parsed))
self.assertEqual(8, len(self.config.parser.parsed))
# ensure we successfully parsed a file for ssl_options
self.assertTrue(self.config.parser.loc["ssl_options"])
@ -68,7 +68,8 @@ class NginxConfiguratorTest(util.NginxTest):
names = self.config.get_all_names()
self.assertEqual(names, set(
["155.225.50.69.nephoscale.net", "www.example.org", "another.alias",
"migration.com", "summer.com", "geese.com", "sslon.com"]))
"migration.com", "summer.com", "geese.com", "sslon.com",
"globalssl.com", "globalsslsetssl.com"]))
def test_supported_enhancements(self):
self.assertEqual(['redirect', 'staple-ocsp'],

View file

@ -49,7 +49,8 @@ class NginxParserTest(util.NginxTest):
'sites-enabled/default',
'sites-enabled/example.com',
'sites-enabled/migration.com',
'sites-enabled/sslon.com']]),
'sites-enabled/sslon.com',
'sites-enabled/globalssl.com']]),
set(nparser.parsed.keys()))
self.assertEqual([['server_name', 'somename alias another.alias']],
nparser.parsed[nparser.abs_path('server.conf')])
@ -73,7 +74,7 @@ class NginxParserTest(util.NginxTest):
parsed = nparser._parse_files(nparser.abs_path(
'sites-enabled/example.com.test'))
self.assertEqual(3, len(glob.glob(nparser.abs_path('*.test'))))
self.assertEqual(4, len(
self.assertEqual(5, len(
glob.glob(nparser.abs_path('sites-enabled/*.test'))))
self.assertEqual([[['server'], [['listen', '69.50.225.155:9000'],
['listen', '127.0.0.1'],
@ -104,6 +105,16 @@ class NginxParserTest(util.NginxTest):
lambda x, y, pts=paths: pts.append(y))
self.assertEqual(paths, result)
def test_get_vhosts_global_ssl(self):
nparser = parser.NginxParser(self.config_path, self.ssl_options)
vhosts = nparser.get_vhosts()
vhost = obj.VirtualHost(nparser.abs_path('sites-enabled/globalssl.com'),
[obj.Addr('4.8.2.6', '57', True, False)],
True, True, set(['globalssl.com']), [], [0])
globalssl_com = [x for x in vhosts if 'globalssl.com' in x.filep][0]
self.assertEqual(vhost, globalssl_com)
def test_get_vhosts(self):
nparser = parser.NginxParser(self.config_path, self.ssl_options)
@ -137,7 +148,7 @@ class NginxParserTest(util.NginxTest):
'*.www.example.com']),
[], [2, 1, 0])
self.assertEqual(8, len(vhosts))
self.assertEqual(10, len(vhosts))
example_com = [x for x in vhosts if 'example.com' in x.filep][0]
self.assertEqual(vhost3, example_com)
default = [x for x in vhosts if 'default' in x.filep][0]
@ -291,27 +302,34 @@ class NginxParserTest(util.NginxTest):
COMMENT_BLOCK,
["\n", "e", " ", "f"]])
def test_parse_server_ssl(self):
server = parser.parse_server([
def test_parse_server_raw_ssl(self):
server = parser._parse_server_raw([ #pylint: disable=protected-access
['listen', '443']
])
self.assertFalse(server['ssl'])
server = parser.parse_server([
server = parser._parse_server_raw([ #pylint: disable=protected-access
['listen', '443 ssl']
])
self.assertTrue(server['ssl'])
server = parser.parse_server([
server = parser._parse_server_raw([ #pylint: disable=protected-access
['listen', '443'], ['ssl', 'off']
])
self.assertFalse(server['ssl'])
server = parser.parse_server([
server = parser._parse_server_raw([ #pylint: disable=protected-access
['listen', '443'], ['ssl', 'on']
])
self.assertTrue(server['ssl'])
def test_parse_server_global_ssl_applied(self):
nparser = parser.NginxParser(self.config_path, self.ssl_options)
server = nparser.parse_server([
['listen', '443']
])
self.assertTrue(server['ssl'])
def test_ssl_options_should_be_parsed_ssl_directives(self):
nparser = parser.NginxParser(self.config_path, self.ssl_options)
self.assertEqual(nginxparser.UnspacedList(nparser.loc["ssl_options"]),

View file

@ -0,0 +1,9 @@
server {
server_name globalssl.com;
listen 4.8.2.6:57;
}
server {
server_name globalsslsetssl.com;
listen 4.8.2.6:57 ssl;
}

View file

@ -127,17 +127,18 @@ class Addr(object):
return "%s:%s" % self.tup
return self.tup[0]
def normalized_tuple(self):
"""Normalized representation of addr/port tuple
"""
if self.ipv6:
return (self._normalize_ipv6(self.tup[0]), self.tup[1])
return self.tup
def __eq__(self, other):
if isinstance(other, self.__class__):
if self.ipv6:
# compare normalized to take different
# styles of representation into account
return (other.ipv6 and
self._normalize_ipv6(self.tup[0]) ==
self._normalize_ipv6(other.tup[0]) and
self.tup[1] == other.tup[1])
else:
return self.tup == other.tup
# compare normalized to take different
# styles of representation into account
return self.normalized_tuple() == other.normalized_tuple()
return False