diff --git a/certbot-nginx/certbot_nginx/obj.py b/certbot-nginx/certbot_nginx/obj.py index 4a3ca865e..98bf86f5c 100644 --- a/certbot-nginx/certbot_nginx/obj.py +++ b/certbot-nginx/certbot_nginx/obj.py @@ -93,9 +93,16 @@ class Addr(common.Addr): def __repr__(self): return "Addr(" + self.__str__() + ")" + def super_eq(self, other): + """Check ip/port equality, with IPv6 support. + """ + # Nginx plugin currently doesn't support IPv6 but this will + # future-proof it + return super(Addr, self).__eq__(other) + def __eq__(self, other): if isinstance(other, self.__class__): - return (self.tup == other.tup and + return (self.super_eq(other) and self.ssl == other.ssl and self.default == other.default) return False diff --git a/certbot-nginx/certbot_nginx/parser.py b/certbot-nginx/certbot_nginx/parser.py index 385635212..1a2c85c2c 100644 --- a/certbot-nginx/certbot_nginx/parser.py +++ b/certbot-nginx/certbot_nginx/parser.py @@ -82,21 +82,28 @@ class NginxParser(object): else: return path - def get_vhosts(self): - # pylint: disable=cell-var-from-loop - """Gets list of all 'virtual hosts' found in Nginx configuration. - Technically this is a misnomer because Nginx does not have virtual - hosts, it has 'server blocks'. - - :returns: List of :class:`~certbot_nginx.obj.VirtualHost` - objects found in configuration - :rtype: list - + def _build_addr_to_ssl(self): + """Builds a map from address to whether it listens on ssl in any server block """ - enabled = True # We only look at enabled vhosts for now - vhosts = [] - servers = {} + servers = self._get_raw_servers() + addr_to_ssl = {} + for filename in servers: + for server, _ in servers[filename]: + # Parse the server block to save addr info + parsed_server = _parse_server_raw(server) + for addr in parsed_server['addrs']: + addr_tuple = addr.normalized_tuple() + if addr_tuple not in addr_to_ssl: + addr_to_ssl[addr_tuple] = addr.ssl + addr_to_ssl[addr_tuple] = addr.ssl or addr_to_ssl[addr_tuple] + return addr_to_ssl + + def _get_raw_servers(self): + # pylint: disable=cell-var-from-loop + """Get a map of unparsed all server blocks + """ + servers = {} for filename in self.parsed: tree = self.parsed[filename] servers[filename] = [] @@ -110,12 +117,28 @@ class NginxParser(object): for i, (server, path) in enumerate(servers[filename]): new_server = self._get_included_directives(server) servers[filename][i] = (new_server, path) + return servers + def get_vhosts(self): + # pylint: disable=cell-var-from-loop + """Gets list of all 'virtual hosts' found in Nginx configuration. + Technically this is a misnomer because Nginx does not have virtual + hosts, it has 'server blocks'. + + :returns: List of :class:`~certbot_nginx.obj.VirtualHost` + objects found in configuration + :rtype: list + + """ + enabled = True # We only look at enabled vhosts for now + servers = self._get_raw_servers() + + vhosts = [] for filename in servers: for server, path in servers[filename]: # Parse the server block into a VirtualHost object - parsed_server = parse_server(server) + parsed_server = _parse_server_raw(server) vhost = obj.VirtualHost(filename, parsed_server['addrs'], parsed_server['ssl'], @@ -125,8 +148,20 @@ class NginxParser(object): path) vhosts.append(vhost) + self._update_vhosts_addrs_ssl(vhosts) + return vhosts + def _update_vhosts_addrs_ssl(self, vhosts): + """Update a list of raw parsed vhosts to include global address sslishness + """ + addr_to_ssl = self._build_addr_to_ssl() + for vhost in vhosts: + for addr in vhost.addrs: + addr.ssl = addr_to_ssl[addr.normalized_tuple()] + if addr.ssl: + vhost.ssl = True + def _get_included_directives(self, block): """Returns array with the "include" directives expanded out by concatenating the contents of the included file to the block. @@ -241,6 +276,17 @@ class NginxParser(object): except IOError: logger.error("Could not open file for writing: %s", filename) + def parse_server(self, server): + """Parses a list of server directives, accounting for global address sslishness. + + :param list server: list of directives in a server block + :rtype: dict + """ + addr_to_ssl = self._build_addr_to_ssl() + parsed_server = _parse_server_raw(server) + _apply_global_addr_ssl(addr_to_ssl, parsed_server) + return parsed_server + def has_ssl_on_directive(self, vhost): """Does vhost have ssl on for all ports? @@ -290,7 +336,7 @@ class NginxParser(object): # update vhost based on new directives new_server = self._get_included_directives(result) - parsed_server = parse_server(new_server) + parsed_server = self.parse_server(new_server) vhost.addrs = parsed_server['addrs'] vhost.ssl = parsed_server['ssl'] vhost.names = parsed_server['names'] @@ -434,41 +480,6 @@ def _get_servernames(names): names = re.sub(whitespace_re, ' ', names) return names.split(' ') - -def parse_server(server): - """Parses a list of server directives. - - :param list server: list of directives in a server block - :rtype: dict - - """ - parsed_server = {'addrs': set(), - 'ssl': False, - 'names': set()} - - apply_ssl_to_all_addrs = False - - for directive in server: - if not directive: - continue - if directive[0] == 'listen': - addr = obj.Addr.fromstring(directive[1]) - parsed_server['addrs'].add(addr) - if not parsed_server['ssl'] and addr.ssl: - parsed_server['ssl'] = True - elif directive[0] == 'server_name': - parsed_server['names'].update( - _get_servernames(directive[1])) - elif directive[0] == 'ssl' and directive[1] == 'on': - parsed_server['ssl'] = True - apply_ssl_to_all_addrs = True - - if apply_ssl_to_all_addrs: - for addr in parsed_server['addrs']: - addr.ssl = True - - return parsed_server - def _add_directives(block, directives, replace): """Adds or replaces directives in a config block. @@ -549,3 +560,44 @@ def _add_directive(block, directive, replace): 'tried to insert directive "{0}" but found ' 'conflicting "{1}".'.format(directive, block[location])) +def _apply_global_addr_ssl(addr_to_ssl, parsed_server): + """Apply global sslishness information to the parsed server block + """ + for addr in parsed_server['addrs']: + addr.ssl = addr_to_ssl[addr.normalized_tuple()] + if addr.ssl: + parsed_server['ssl'] = True + +def _parse_server_raw(server): + """Parses a list of server directives. + + :param list server: list of directives in a server block + :rtype: dict + + """ + parsed_server = {'addrs': set(), + 'ssl': False, + 'names': set()} + + apply_ssl_to_all_addrs = False + + for directive in server: + if not directive: + continue + if directive[0] == 'listen': + addr = obj.Addr.fromstring(directive[1]) + parsed_server['addrs'].add(addr) + if addr.ssl: + parsed_server['ssl'] = True + elif directive[0] == 'server_name': + parsed_server['names'].update( + _get_servernames(directive[1])) + elif directive[0] == 'ssl' and directive[1] == 'on': + parsed_server['ssl'] = True + apply_ssl_to_all_addrs = True + + if apply_ssl_to_all_addrs: + for addr in parsed_server['addrs']: + addr.ssl = True + + return parsed_server diff --git a/certbot-nginx/certbot_nginx/tests/configurator_test.py b/certbot-nginx/certbot_nginx/tests/configurator_test.py index f165ea23a..08a66fc98 100644 --- a/certbot-nginx/certbot_nginx/tests/configurator_test.py +++ b/certbot-nginx/certbot_nginx/tests/configurator_test.py @@ -40,7 +40,7 @@ class NginxConfiguratorTest(util.NginxTest): def test_prepare(self): self.assertEqual((1, 6, 2), self.config.version) - self.assertEqual(7, len(self.config.parser.parsed)) + self.assertEqual(8, len(self.config.parser.parsed)) # ensure we successfully parsed a file for ssl_options self.assertTrue(self.config.parser.loc["ssl_options"]) @@ -68,7 +68,8 @@ class NginxConfiguratorTest(util.NginxTest): names = self.config.get_all_names() self.assertEqual(names, set( ["155.225.50.69.nephoscale.net", "www.example.org", "another.alias", - "migration.com", "summer.com", "geese.com", "sslon.com"])) + "migration.com", "summer.com", "geese.com", "sslon.com", + "globalssl.com", "globalsslsetssl.com"])) def test_supported_enhancements(self): self.assertEqual(['redirect', 'staple-ocsp'], diff --git a/certbot-nginx/certbot_nginx/tests/parser_test.py b/certbot-nginx/certbot_nginx/tests/parser_test.py index 54deffd7a..921cc3c5a 100644 --- a/certbot-nginx/certbot_nginx/tests/parser_test.py +++ b/certbot-nginx/certbot_nginx/tests/parser_test.py @@ -49,7 +49,8 @@ class NginxParserTest(util.NginxTest): 'sites-enabled/default', 'sites-enabled/example.com', 'sites-enabled/migration.com', - 'sites-enabled/sslon.com']]), + 'sites-enabled/sslon.com', + 'sites-enabled/globalssl.com']]), set(nparser.parsed.keys())) self.assertEqual([['server_name', 'somename alias another.alias']], nparser.parsed[nparser.abs_path('server.conf')]) @@ -73,7 +74,7 @@ class NginxParserTest(util.NginxTest): parsed = nparser._parse_files(nparser.abs_path( 'sites-enabled/example.com.test')) self.assertEqual(3, len(glob.glob(nparser.abs_path('*.test')))) - self.assertEqual(4, len( + self.assertEqual(5, len( glob.glob(nparser.abs_path('sites-enabled/*.test')))) self.assertEqual([[['server'], [['listen', '69.50.225.155:9000'], ['listen', '127.0.0.1'], @@ -104,6 +105,16 @@ class NginxParserTest(util.NginxTest): lambda x, y, pts=paths: pts.append(y)) self.assertEqual(paths, result) + def test_get_vhosts_global_ssl(self): + nparser = parser.NginxParser(self.config_path, self.ssl_options) + vhosts = nparser.get_vhosts() + + vhost = obj.VirtualHost(nparser.abs_path('sites-enabled/globalssl.com'), + [obj.Addr('4.8.2.6', '57', True, False)], + True, True, set(['globalssl.com']), [], [0]) + + globalssl_com = [x for x in vhosts if 'globalssl.com' in x.filep][0] + self.assertEqual(vhost, globalssl_com) def test_get_vhosts(self): nparser = parser.NginxParser(self.config_path, self.ssl_options) @@ -137,7 +148,7 @@ class NginxParserTest(util.NginxTest): '*.www.example.com']), [], [2, 1, 0]) - self.assertEqual(8, len(vhosts)) + self.assertEqual(10, len(vhosts)) example_com = [x for x in vhosts if 'example.com' in x.filep][0] self.assertEqual(vhost3, example_com) default = [x for x in vhosts if 'default' in x.filep][0] @@ -291,27 +302,34 @@ class NginxParserTest(util.NginxTest): COMMENT_BLOCK, ["\n", "e", " ", "f"]]) - def test_parse_server_ssl(self): - server = parser.parse_server([ + def test_parse_server_raw_ssl(self): + server = parser._parse_server_raw([ #pylint: disable=protected-access ['listen', '443'] ]) self.assertFalse(server['ssl']) - server = parser.parse_server([ + server = parser._parse_server_raw([ #pylint: disable=protected-access ['listen', '443 ssl'] ]) self.assertTrue(server['ssl']) - server = parser.parse_server([ + server = parser._parse_server_raw([ #pylint: disable=protected-access ['listen', '443'], ['ssl', 'off'] ]) self.assertFalse(server['ssl']) - server = parser.parse_server([ + server = parser._parse_server_raw([ #pylint: disable=protected-access ['listen', '443'], ['ssl', 'on'] ]) self.assertTrue(server['ssl']) + def test_parse_server_global_ssl_applied(self): + nparser = parser.NginxParser(self.config_path, self.ssl_options) + server = nparser.parse_server([ + ['listen', '443'] + ]) + self.assertTrue(server['ssl']) + def test_ssl_options_should_be_parsed_ssl_directives(self): nparser = parser.NginxParser(self.config_path, self.ssl_options) self.assertEqual(nginxparser.UnspacedList(nparser.loc["ssl_options"]), diff --git a/certbot-nginx/certbot_nginx/tests/testdata/etc_nginx/sites-enabled/globalssl.com b/certbot-nginx/certbot_nginx/tests/testdata/etc_nginx/sites-enabled/globalssl.com new file mode 100644 index 000000000..969447d6e --- /dev/null +++ b/certbot-nginx/certbot_nginx/tests/testdata/etc_nginx/sites-enabled/globalssl.com @@ -0,0 +1,9 @@ +server { + server_name globalssl.com; + listen 4.8.2.6:57; +} + +server { + server_name globalsslsetssl.com; + listen 4.8.2.6:57 ssl; +} diff --git a/certbot/plugins/common.py b/certbot/plugins/common.py index 007105c7b..46d4c5740 100644 --- a/certbot/plugins/common.py +++ b/certbot/plugins/common.py @@ -127,17 +127,18 @@ class Addr(object): return "%s:%s" % self.tup return self.tup[0] + def normalized_tuple(self): + """Normalized representation of addr/port tuple + """ + if self.ipv6: + return (self._normalize_ipv6(self.tup[0]), self.tup[1]) + return self.tup + def __eq__(self, other): if isinstance(other, self.__class__): - if self.ipv6: - # compare normalized to take different - # styles of representation into account - return (other.ipv6 and - self._normalize_ipv6(self.tup[0]) == - self._normalize_ipv6(other.tup[0]) and - self.tup[1] == other.tup[1]) - else: - return self.tup == other.tup + # compare normalized to take different + # styles of representation into account + return self.normalized_tuple() == other.normalized_tuple() return False