From 239184882e1694bee62b12e43e3d4e0a1f08b9da Mon Sep 17 00:00:00 2001 From: ohemorange Date: Tue, 6 Jun 2017 17:04:45 -0700 Subject: [PATCH] Enable IPv6 support in standalone plugin (#4773) * add TLSSNI01DualNetworkedServers * use DualNetworkedServers in certbot/plugins/standalone.py also, make both servers run on the same port. * make probe_sni connect on ipv6 and ipv4 using None * mimic BSD-like conditions to get test coverage * test ServerManager taking into account BSD systems * pass tests even if python is compiled without ipv6 support --- acme/acme/crypto_util.py | 9 +- acme/acme/standalone.py | 108 ++++++++++++++++++++-- acme/acme/standalone_test.py | 139 +++++++++++++++++++++++++++++ certbot/plugins/standalone.py | 59 ++++++------ certbot/plugins/standalone_test.py | 11 ++- 5 files changed, 282 insertions(+), 44 deletions(-) diff --git a/acme/acme/crypto_util.py b/acme/acme/crypto_util.py index f86a9971a..84b70e4a6 100644 --- a/acme/acme/crypto_util.py +++ b/acme/acme/crypto_util.py @@ -107,7 +107,7 @@ class SSLSocket(object): # pylint: disable=too-few-public-methods def probe_sni(name, host, port=443, timeout=300, - method=_DEFAULT_TLSSNI01_SSL_METHOD, source_address=('0', 0)): + method=_DEFAULT_TLSSNI01_SSL_METHOD, source_address=('', 0)): """Probe SNI server for SSL certificate. :param bytes name: Byte string to send as the server name in the @@ -132,9 +132,14 @@ def probe_sni(name, host, port=443, timeout=300, socket_kwargs = {} if sys.version_info < (2, 7) else { 'source_address': source_address} + host_protocol_agnostic = None if host == '::' or host == '0' else host + try: # pylint: disable=star-args - sock = socket.create_connection((host, port), **socket_kwargs) + logger.debug("Attempting to connect to %s:%d%s.", host_protocol_agnostic, port, + " from {0}:{1}".format(source_address[0], source_address[1]) if \ + socket_kwargs else "") + sock = socket.create_connection((host_protocol_agnostic, port), **socket_kwargs) except socket.error as error: raise errors.Error(error) diff --git a/acme/acme/standalone.py b/acme/acme/standalone.py index 087240c15..c221f5883 100644 --- a/acme/acme/standalone.py +++ b/acme/acme/standalone.py @@ -4,7 +4,9 @@ import collections import functools import logging import os +import socket import sys +import threading from six.moves import BaseHTTPServer # type: ignore # pylint: disable=import-error from six.moves import http_client # pylint: disable=import-error @@ -26,6 +28,11 @@ class TLSServer(socketserver.TCPServer): """Generic TLS Server.""" def __init__(self, *args, **kwargs): + self.ipv6 = kwargs.pop("ipv6", False) + if self.ipv6: + self.address_family = socket.AF_INET6 + else: + self.address_family = socket.AF_INET self.certs = kwargs.pop("certs", {}) self.method = kwargs.pop( # pylint: disable=protected-access @@ -49,12 +56,81 @@ class ACMEServerMixin: # pylint: disable=old-style-class allow_reuse_address = True +class BaseDualNetworkedServers(object): + """Base class for a pair of IPv6 and IPv4 servers that tries to do everything + it's asked for both servers, but where failures in one server don't + affect the other. + + If two servers are instantiated, they will serve on the same port. + """ + + def __init__(self, ServerClass, server_address, *remaining_args, **kwargs): + port = server_address[1] + self.threads = [] + self.servers = [] + + # Must try True first. + # Ubuntu, for example, will fail to bind to IPv4 if we've already bound + # to IPv6. But that's ok, since it will accept IPv4 connections on the IPv6 + # socket. On the other hand, FreeBSD will successfully bind to IPv4 on the + # same port, which means that server will accept the IPv4 connections. + # If Python is compiled without IPv6, we'll error out but (probably) successfully + # create the IPv4 server. + for ip_version in [True, False]: + try: + kwargs["ipv6"] = ip_version + new_address = (server_address[0],) + (port,) + server_address[2:] + new_args = (new_address,) + remaining_args + server = ServerClass(*new_args, **kwargs) # pylint: disable=star-args + except socket.error: + logger.debug("Failed to bind to %s:%s using %s", new_address[0], + new_address[1], "IPv6" if ip_version else "IPv4") + else: + self.servers.append(server) + # If two servers are set up and port 0 was passed in, ensure we always + # bind to the same port for both servers. + port = server.socket.getsockname()[1] + if len(self.servers) == 0: + raise socket.error("Could not bind to IPv4 or IPv6.") + + def serve_forever(self): + """Wraps socketserver.TCPServer.serve_forever""" + for server in self.servers: + thread = threading.Thread( + # pylint: disable=no-member + target=server.serve_forever) + thread.start() + self.threads.append(thread) + + def getsocknames(self): + """Wraps socketserver.TCPServer.socket.getsockname""" + return [server.socket.getsockname() for server in self.servers] + + def shutdown_and_server_close(self): + """Wraps socketserver.TCPServer.shutdown, socketserver.TCPServer.server_close, and + threading.Thread.join""" + for server in self.servers: + server.shutdown() + server.server_close() + for thread in self.threads: + thread.join() + self.threads = [] + + class TLSSNI01Server(TLSServer, ACMEServerMixin): """TLSSNI01 Server.""" - def __init__(self, server_address, certs): + def __init__(self, server_address, certs, ipv6=False): TLSServer.__init__( - self, server_address, BaseRequestHandlerWithLogging, certs=certs) + self, server_address, BaseRequestHandlerWithLogging, certs=certs, ipv6=ipv6) + + +class TLSSNI01DualNetworkedServers(BaseDualNetworkedServers): + """TLSSNI01Server Wrapper. Tries everything for both. Failures for one don't + affect the other.""" + + def __init__(self, *args, **kwargs): + BaseDualNetworkedServers.__init__(self, TLSSNI01Server, *args, **kwargs) class BaseRequestHandlerWithLogging(socketserver.BaseRequestHandler): @@ -70,13 +146,33 @@ class BaseRequestHandlerWithLogging(socketserver.BaseRequestHandler): socketserver.BaseRequestHandler.handle(self) -class HTTP01Server(BaseHTTPServer.HTTPServer, ACMEServerMixin): +class HTTPServer(BaseHTTPServer.HTTPServer): + """Generic HTTP Server.""" + + def __init__(self, *args, **kwargs): + self.ipv6 = kwargs.pop("ipv6", False) + if self.ipv6: + self.address_family = socket.AF_INET6 + else: + self.address_family = socket.AF_INET + BaseHTTPServer.HTTPServer.__init__(self, *args, **kwargs) + + +class HTTP01Server(HTTPServer, ACMEServerMixin): """HTTP01 Server.""" - def __init__(self, server_address, resources): - BaseHTTPServer.HTTPServer.__init__( + def __init__(self, server_address, resources, ipv6=False): + HTTPServer.__init__( self, server_address, HTTP01RequestHandler.partial_init( - simple_http_resources=resources)) + simple_http_resources=resources), ipv6=ipv6) + + +class HTTP01DualNetworkedServers(BaseDualNetworkedServers): + """HTTP01Server Wrapper. Tries everything for both. Failures for one don't + affect the other.""" + + def __init__(self, *args, **kwargs): + BaseDualNetworkedServers.__init__(self, HTTP01Server, *args, **kwargs) class HTTP01RequestHandler(BaseHTTPServer.BaseHTTPRequestHandler): diff --git a/acme/acme/standalone_test.py b/acme/acme/standalone_test.py index c3beab34b..16669680c 100644 --- a/acme/acme/standalone_test.py +++ b/acme/acme/standalone_test.py @@ -1,6 +1,7 @@ """Tests for acme.standalone.""" import os import shutil +import socket import threading import tempfile import time @@ -9,6 +10,7 @@ import unittest from six.moves import http_client # pylint: disable=import-error from six.moves import socketserver # type: ignore # pylint: disable=import-error +import mock import requests from acme import challenges @@ -29,6 +31,13 @@ class TLSServerTest(unittest.TestCase): ('', 0), socketserver.BaseRequestHandler, bind_and_activate=True) server.server_close() # pylint: disable=no-member + def test_ipv6(self): + if socket.has_ipv6: + from acme.standalone import TLSServer + server = TLSServer( + ('', 0), socketserver.BaseRequestHandler, bind_and_activate=True, ipv6=True) + server.server_close() # pylint: disable=no-member + class TLSSNI01ServerTest(unittest.TestCase): """Test for acme.standalone.TLSSNI01Server.""" @@ -112,6 +121,136 @@ class HTTP01ServerTest(unittest.TestCase): self.assertFalse(self._test_http01(add=False)) +class BaseDualNetworkedServersTest(unittest.TestCase): + """Test for acme.standalone.BaseDualNetworkedServers.""" + + _multiprocess_can_split_ = True + + class SingleProtocolServer(socketserver.TCPServer): + """Server that only serves on a single protocol. FreeBSD has this behavior for AF_INET6.""" + def __init__(self, *args, **kwargs): + ipv6 = kwargs.pop("ipv6", False) + if ipv6: + self.address_family = socket.AF_INET6 + kwargs["bind_and_activate"] = False + else: + self.address_family = socket.AF_INET + socketserver.TCPServer.__init__(self, *args, **kwargs) + if ipv6: + # pylint: disable=no-member + self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1) + try: + self.server_bind() + self.server_activate() + except: + self.server_close() + raise + + @mock.patch("socket.socket.bind") + def test_fail_to_bind(self, mock_bind): + mock_bind.side_effect = socket.error + from acme.standalone import BaseDualNetworkedServers + self.assertRaises(socket.error, BaseDualNetworkedServers, + BaseDualNetworkedServersTest.SingleProtocolServer, + ("", 0), + socketserver.BaseRequestHandler) + + def test_ports_equal(self): + from acme.standalone import BaseDualNetworkedServers + servers = BaseDualNetworkedServers( + BaseDualNetworkedServersTest.SingleProtocolServer, + ("", 0), + socketserver.BaseRequestHandler) + socknames = servers.getsocknames() + prev_port = None + # assert ports are equal + for sockname in socknames: + port = sockname[1] + if prev_port: + self.assertEqual(prev_port, port) + prev_port = port + + +class TLSSNI01DualNetworkedServersTest(unittest.TestCase): + """Test for acme.standalone.TLSSNI01DualNetworkedServers.""" + + _multiprocess_can_split_ = True + + def setUp(self): + self.certs = {b'localhost': ( + test_util.load_pyopenssl_private_key('rsa2048_key.pem'), + test_util.load_cert('rsa2048_cert.pem'), + )} + from acme.standalone import TLSSNI01DualNetworkedServers + self.servers = TLSSNI01DualNetworkedServers(("", 0), certs=self.certs) + self.servers.serve_forever() + + def tearDown(self): + self.servers.shutdown_and_server_close() + + def test_connect(self): + socknames = self.servers.getsocknames() + # connect to all addresses + for sockname in socknames: + host, port = sockname[:2] + cert = crypto_util.probe_sni( + b'localhost', host=host, port=port, timeout=1) + self.assertEqual(jose.ComparableX509(cert), + jose.ComparableX509(self.certs[b'localhost'][1])) + + +class HTTP01DualNetworkedServersTest(unittest.TestCase): + """Tests for acme.standalone.HTTP01DualNetworkedServers.""" + + _multiprocess_can_split_ = True + + def setUp(self): + self.account_key = jose.JWK.load( + test_util.load_vector('rsa1024_key.pem')) + self.resources = set() + + from acme.standalone import HTTP01DualNetworkedServers + self.servers = HTTP01DualNetworkedServers(('', 0), resources=self.resources) + + # pylint: disable=no-member + self.port = self.servers.getsocknames()[0][1] + self.servers.serve_forever() + + def tearDown(self): + self.servers.shutdown_and_server_close() + + def test_index(self): + response = requests.get( + 'http://localhost:{0}'.format(self.port), verify=False) + self.assertEqual( + response.text, 'ACME client standalone challenge solver') + self.assertTrue(response.ok) + + def test_404(self): + response = requests.get( + 'http://localhost:{0}/foo'.format(self.port), verify=False) + self.assertEqual(response.status_code, http_client.NOT_FOUND) + + def _test_http01(self, add): + chall = challenges.HTTP01(token=(b'x' * 16)) + response, validation = chall.response_and_validation(self.account_key) + + from acme.standalone import HTTP01RequestHandler + resource = HTTP01RequestHandler.HTTP01Resource( + chall=chall, response=response, validation=validation) + if add: + self.resources.add(resource) + return resource.response.simple_verify( + resource.chall, 'localhost', self.account_key.public_key(), + port=self.port) + + def test_http01_found(self): + self.assertTrue(self._test_http01(add=True)) + + def test_http01_not_found(self): + self.assertFalse(self._test_http01(add=False)) + + class TestSimpleTLSSNI01Server(unittest.TestCase): """Tests for acme.standalone.simple_tls_sni_01_server.""" diff --git a/certbot/plugins/standalone.py b/certbot/plugins/standalone.py index ce878f84a..817403bd3 100644 --- a/certbot/plugins/standalone.py +++ b/certbot/plugins/standalone.py @@ -3,7 +3,6 @@ import argparse import collections import logging import socket -import threading import OpenSSL import six @@ -33,8 +32,6 @@ class ServerManager(object): will serve the same URLs! """ - _Instance = collections.namedtuple("_Instance", "server thread") - def __init__(self, certs, http_01_resources): self._instances = {} self.certs = certs @@ -51,34 +48,32 @@ class ServerManager(object): either `acme.challenge.HTTP01` or `acme.challenges.TLSSNI01`. :param str listenaddr: (optional) The address to listen on. Defaults to all addrs. - :returns: Server instance. + :returns: DualNetworkedServers instance. :rtype: ACMEServerMixin """ assert challenge_type in (challenges.TLSSNI01, challenges.HTTP01) if port in self._instances: - return self._instances[port].server + return self._instances[port] address = (listenaddr, port) try: if challenge_type is challenges.TLSSNI01: - server = acme_standalone.TLSSNI01Server(address, self.certs) + servers = acme_standalone.TLSSNI01DualNetworkedServers(address, self.certs) else: # challenges.HTTP01 - server = acme_standalone.HTTP01Server( + servers = acme_standalone.HTTP01DualNetworkedServers( address, self.http_01_resources) except socket.error as error: raise errors.StandaloneBindError(error, port) - thread = threading.Thread( - # pylint: disable=no-member - target=server.serve_forever) - thread.start() + servers.serve_forever() # if port == 0, then random free port on OS is taken # pylint: disable=no-member - real_port = server.socket.getsockname()[1] - self._instances[real_port] = self._Instance(server, thread) - return server + # both servers, if they exist, have the same port + real_port = servers.getsocknames()[0][1] + self._instances[real_port] = servers + return servers def stop(self, port): """Stop ACME server running on the specified ``port``. @@ -87,13 +82,12 @@ class ServerManager(object): """ instance = self._instances[port] - logger.debug("Stopping server at %s:%d...", - *instance.server.socket.getsockname()[:2]) - instance.server.shutdown() + for sockname in instance.getsocknames(): + logger.debug("Stopping server at %s:%d...", + *sockname[:2]) # Not calling server_close causes problems when renewing multiple # certs with `certbot renew` using TLSSNI01 and PyOpenSSL 0.13 - instance.server.server_close() - instance.thread.join() + instance.shutdown_and_server_close() del self._instances[port] def running(self): @@ -102,12 +96,11 @@ class ServerManager(object): Once the server is stopped using `stop`, it will not be returned. - :returns: Mapping from ``port`` to ``server``. + :returns: Mapping from ``port`` to ``servers``. :rtype: tuple """ - return dict((port, instance.server) for port, instance - in six.iteritems(self._instances)) + return self._instances.copy() SUPPORTED_CHALLENGES = [challenges.TLSSNI01, challenges.HTTP01] @@ -236,38 +229,38 @@ class Authenticator(common.Plugin): def _perform_single(self, achall): if isinstance(achall.chall, challenges.HTTP01): - server, response = self._perform_http_01(achall) + servers, response = self._perform_http_01(achall) else: # tls-sni-01 - server, response = self._perform_tls_sni_01(achall) - self.served[server].add(achall) + servers, response = self._perform_tls_sni_01(achall) + self.served[servers].add(achall) return response def _perform_http_01(self, achall): port = self.config.http01_port addr = self.config.http01_address - server = self.servers.run(port, challenges.HTTP01, listenaddr=addr) + servers = self.servers.run(port, challenges.HTTP01, listenaddr=addr) response, validation = achall.response_and_validation() resource = acme_standalone.HTTP01RequestHandler.HTTP01Resource( chall=achall.chall, response=response, validation=validation) self.http_01_resources.add(resource) - return server, response + return servers, response def _perform_tls_sni_01(self, achall): port = self.config.tls_sni_01_port addr = self.config.tls_sni_01_address - server = self.servers.run(port, challenges.TLSSNI01, listenaddr=addr) + servers = self.servers.run(port, challenges.TLSSNI01, listenaddr=addr) response, (cert, _) = achall.response_and_validation(cert_key=self.key) self.certs[response.z_domain] = (self.key, cert) - return server, response + return servers, response def cleanup(self, achalls): # pylint: disable=missing-docstring - # reduce self.served and close servers if none challenges are served - for server, server_achalls in self.served.items(): + # reduce self.served and close servers if no challenges are served + for unused_servers, server_achalls in self.served.items(): for achall in achalls: if achall in server_achalls: server_achalls.remove(achall) - for port, server in six.iteritems(self.servers.running()): - if not self.served[server]: + for port, servers in six.iteritems(self.servers.running()): + if not self.served[servers]: self.servers.stop(port) diff --git a/certbot/plugins/standalone_test.py b/certbot/plugins/standalone_test.py index 65d16c2f2..2a55c516f 100644 --- a/certbot/plugins/standalone_test.py +++ b/certbot/plugins/standalone_test.py @@ -32,7 +32,7 @@ class ServerManagerTest(unittest.TestCase): def _test_run_stop(self, challenge_type): server = self.mgr.run(port=0, challenge_type=challenge_type) - port = server.socket.getsockname()[1] # pylint: disable=no-member + port = server.getsocknames()[0][1] # pylint: disable=no-member self.assertEqual(self.mgr.running(), {port: server}) self.mgr.stop(port=port) self.assertEqual(self.mgr.running(), {}) @@ -45,7 +45,7 @@ class ServerManagerTest(unittest.TestCase): def test_run_idempotent(self): server = self.mgr.run(port=0, challenge_type=challenges.HTTP01) - port = server.socket.getsockname()[1] # pylint: disable=no-member + port = server.getsocknames()[0][1] # pylint: disable=no-member server2 = self.mgr.run(port=port, challenge_type=challenges.HTTP01) self.assertEqual(self.mgr.running(), {port: server}) self.assertTrue(server is server2) @@ -53,9 +53,14 @@ class ServerManagerTest(unittest.TestCase): self.assertEqual(self.mgr.running(), {}) def test_run_bind_error(self): - some_server = socket.socket() + some_server = socket.socket(socket.AF_INET6) some_server.bind(("", 0)) port = some_server.getsockname()[1] + maybe_another_server = socket.socket() + try: + maybe_another_server.bind(("", port)) + except socket.error: + pass self.assertRaises( errors.StandaloneBindError, self.mgr.run, port, challenge_type=challenges.HTTP01)