Add get_all_certs_keys method to parser

This commit is contained in:
yan 2015-04-14 16:24:10 -07:00
parent fe1ba9dad6
commit d2588de4fd
5 changed files with 64 additions and 33 deletions

View file

@ -54,7 +54,7 @@ class NginxConfigurator(object):
def __init__(self, config, version=None):
"""Initialize an Nginx Configurator.
:param tup version: version of Nginx as a tuple (2, 4, 7)
:param tup version: version of Nginx as a tuple (1, 4, 7)
(used mostly for unittesting)
"""
@ -133,6 +133,7 @@ class NginxConfigurator(object):
", ".join(str(addr) for addr in vhost.addrs)))
self.save_notes += "\tssl_certificate %s\n" % cert
self.save_notes += "\tssl_certificate_key %s\n" % key
self.save()
#######################
# Vhost parsing methods
@ -272,23 +273,14 @@ class NginxConfigurator(object):
def get_all_certs_keys(self):
"""Find all existing keys, certs from configuration.
Retrieve all certs and keys set in VirtualHosts on the Nginx server
:returns: list of tuples with form [(cert, key, path)]
cert - str path to certificate file
key - str path to associated key file
path - File path to configuration file.
:rtype: list
:rtype: set
"""
c_k = set()
for vhost in self.vhosts:
if vhost.ssl:
# TODO: get the cert, key, and conf file paths
pass
return c_k
return self.parser.get_all_certs_keys()
##################################
# enhancement methods (IInstaller)
@ -453,6 +445,8 @@ class NginxConfigurator(object):
if title and not temporary:
self.reverter.finalize_checkpoint(title)
self.vhosts = self.parser.get_vhosts()
return True
def recovery_routine(self):

View file

@ -99,13 +99,14 @@ class VirtualHost(object): # pylint: disable=too-few-public-methods
:ivar set addrs: Virtual Host addresses (:class:`set` of :class:`Addr`)
:ivar set names: Server names/aliases of vhost
(:class:`list` of :class:`str`)
:ivar array raw: The raw form of the parsed server block
:ivar bool ssl: SSLEngine on in vhost
:ivar bool enabled: Virtual host is enabled
"""
def __init__(self, filep, addrs, ssl, enabled, names):
def __init__(self, filep, addrs, ssl, enabled, names, raw):
# pylint: disable=too-many-arguments
"""Initialize a VH."""
self.filep = filep
@ -113,6 +114,7 @@ class VirtualHost(object): # pylint: disable=too-few-public-methods
self.names = names
self.ssl = ssl
self.enabled = enabled
self.raw = raw
def __str__(self):
addr_str = ", ".join(str(addr) for addr in self.addrs)

View file

@ -98,7 +98,7 @@ class NginxParser(object):
"""
enabled = True # We only look at enabled vhosts for now
vhosts = []
servers = {} # Map of filename to list of parsed server blocks
servers = {}
for filename in self.parsed:
tree = self.parsed[filename]
@ -128,7 +128,8 @@ class NginxParser(object):
parsed_server['addrs'],
parsed_server['ssl'],
enabled,
parsed_server['names'])
parsed_server['names'],
server)
vhosts.append(vhost)
return vhosts
@ -319,6 +320,30 @@ class NginxParser(object):
lambda x: self._has_server_names(x, names),
lambda x: x.extend(directives))
def get_all_certs_keys(self):
"""Gets all certs and keys in the nginx config.
:returns: list of tuples with form [(cert, key, path)]
cert - str path to certificate file
key - str path to associated key file
path - File path to configuration file.
:rtype: set
"""
c_k = set()
vhosts = self.get_vhosts()
for vhost in vhosts:
tup = [None, None, vhost.filep]
if vhost.ssl:
for directive in vhost.raw:
if directive[0] == 'ssl_certificate':
tup[0] = directive[1]
elif directive[0] == 'ssl_certificate_key':
tup[1] = directive[1]
if tup[0] is not None and tup[1] is not None:
c_k.add(tuple(tup))
return c_k
def _do_for_subarray(entry, condition, func):
"""Executes a function for a subarray of a nested array if it matches

View file

@ -3,14 +3,9 @@ import glob
import os
import re
import shutil
import sys
import unittest
import zope.component
from letsencrypt.client.display import util as display_util
from letsencrypt.client.errors import LetsEncryptMisconfigurationError
from letsencrypt.client.plugins.nginx.nginxparser import dumps
from letsencrypt.client.plugins.nginx.obj import Addr, VirtualHost
from letsencrypt.client.plugins.nginx.parser import NginxParser, get_best_match
@ -23,9 +18,6 @@ class NginxParserTest(util.NginxTest):
def setUp(self):
super(NginxParserTest, self).setUp()
self.maxDiff = None
zope.component.provideUtility(display_util.FileDisplay(sys.stdout))
def tearDown(self):
shutil.rmtree(self.temp_dir)
shutil.rmtree(self.config_dir)
@ -57,7 +49,8 @@ class NginxParserTest(util.NginxTest):
set(parser.parsed.keys()))
self.assertEqual([['server_name', 'somename alias another.alias']],
parser.parsed[parser.abs_path('server.conf')])
self.assertEqual([[['server'], [['listen', '9000'],
self.assertEqual([[['server'], [['listen', '69.50.225.155:9000'],
['listen', '127.0.0.1'],
['server_name', '.example.com'],
['server_name', 'example.*']]]],
parser.parsed[parser.abs_path(
@ -78,7 +71,8 @@ class NginxParserTest(util.NginxTest):
self.assertEqual(3, len(glob.glob(parser.abs_path('*.test'))))
self.assertEqual(2, len(
glob.glob(parser.abs_path('sites-enabled/*.test'))))
self.assertEqual([[['server'], [['listen', '9000'],
self.assertEqual([[['server'], [['listen', '69.50.225.155:9000'],
['listen', '127.0.0.1'],
['server_name', '.example.com'],
['server_name', 'example.*']]]],
parsed[0])
@ -89,21 +83,23 @@ class NginxParserTest(util.NginxTest):
vhost1 = VirtualHost(parser.abs_path('nginx.conf'),
[Addr('', '8080', False, False)],
False, True, set(['localhost']))
False, True, set(['localhost']), [])
vhost2 = VirtualHost(parser.abs_path('nginx.conf'),
[Addr('somename', '8080', False, False),
Addr('', '8000', False, False)],
False, True, set(['somename',
'another.alias', 'alias']))
'another.alias', 'alias']), [])
vhost3 = VirtualHost(parser.abs_path('sites-enabled/example.com'),
[Addr('', '9000', False, False)],
False, True, set(['.example.com', 'example.*']))
[Addr('69.50.225.155', '9000', False, False),
Addr('127.0.0.1', '', False, False)],
False, True, set(['.example.com', 'example.*']),
[])
vhost4 = VirtualHost(parser.abs_path('sites-enabled/default'),
[Addr('myhost', '', False, True)],
False, True, set(['www.example.org']))
False, True, set(['www.example.org']), [])
vhost5 = VirtualHost(parser.abs_path('foo.conf'),
[Addr('*', '80', True, True)],
True, True, set(['*.www.foo.com']))
True, True, set(['*.www.foo.com']), [])
self.assertEqual(5, len(vhosts))
example_com = filter(lambda x: 'example.com' in x.filep, vhosts)[0]
@ -144,7 +140,9 @@ class NginxParserTest(util.NginxTest):
filep, target, [['server_name', 'foo bar']], True)
self.assertEqual(
parser.parsed[filep],
[[['server'], [['listen', '9000'], ['server_name', 'foo bar'],
[[['server'], [['listen', '69.50.225.155:9000'],
['listen', '127.0.0.1'],
['server_name', 'foo bar'],
['server_name', 'foo bar']]]])
self.assertRaises(LetsEncryptMisconfigurationError,
parser.add_server_directives,
@ -183,6 +181,17 @@ class NginxParserTest(util.NginxTest):
for i, winner in enumerate(winners):
self.assertEqual(winner, get_best_match(target_name, names[i]))
def test_get_all_certs_keys(self):
parser = NginxParser(self.config_path, self.ssl_options)
filep = parser.abs_path('sites-enabled/example.com')
parser.add_server_directives(filep,
set(['.example.com', 'example.*']),
[['ssl_certificate', 'foo.pem'],
['ssl_certificate_key', 'bar.key'],
['listen', '443 ssl']])
ck = parser.get_all_certs_keys()
self.assertEqual(set([('foo.pem', 'bar.key', filep)]), ck)
if __name__ == "__main__":
unittest.main()

View file

@ -1,5 +1,6 @@
server {
listen 9000;
listen 69.50.225.155:9000;
listen 127.0.0.1;
server_name .example.com;
server_name example.*;
}