Enable IPv6 support in standalone plugin (#4773)

* add TLSSNI01DualNetworkedServers

* use DualNetworkedServers in certbot/plugins/standalone.py
  also, make both servers run on the same port.

* make probe_sni connect on ipv6 and ipv4 using None

* mimic BSD-like conditions to get test coverage

* test ServerManager taking into account BSD systems

* pass tests even if python is compiled without ipv6 support
This commit is contained in:
ohemorange 2017-06-06 17:04:45 -07:00 committed by GitHub
parent af8dae6cb2
commit 239184882e
5 changed files with 282 additions and 44 deletions

View file

@ -107,7 +107,7 @@ class SSLSocket(object): # pylint: disable=too-few-public-methods
def probe_sni(name, host, port=443, timeout=300,
method=_DEFAULT_TLSSNI01_SSL_METHOD, source_address=('0', 0)):
method=_DEFAULT_TLSSNI01_SSL_METHOD, source_address=('', 0)):
"""Probe SNI server for SSL certificate.
:param bytes name: Byte string to send as the server name in the
@ -132,9 +132,14 @@ def probe_sni(name, host, port=443, timeout=300,
socket_kwargs = {} if sys.version_info < (2, 7) else {
'source_address': source_address}
host_protocol_agnostic = None if host == '::' or host == '0' else host
try:
# pylint: disable=star-args
sock = socket.create_connection((host, port), **socket_kwargs)
logger.debug("Attempting to connect to %s:%d%s.", host_protocol_agnostic, port,
" from {0}:{1}".format(source_address[0], source_address[1]) if \
socket_kwargs else "")
sock = socket.create_connection((host_protocol_agnostic, port), **socket_kwargs)
except socket.error as error:
raise errors.Error(error)

View file

@ -4,7 +4,9 @@ import collections
import functools
import logging
import os
import socket
import sys
import threading
from six.moves import BaseHTTPServer # type: ignore # pylint: disable=import-error
from six.moves import http_client # pylint: disable=import-error
@ -26,6 +28,11 @@ class TLSServer(socketserver.TCPServer):
"""Generic TLS Server."""
def __init__(self, *args, **kwargs):
self.ipv6 = kwargs.pop("ipv6", False)
if self.ipv6:
self.address_family = socket.AF_INET6
else:
self.address_family = socket.AF_INET
self.certs = kwargs.pop("certs", {})
self.method = kwargs.pop(
# pylint: disable=protected-access
@ -49,12 +56,81 @@ class ACMEServerMixin: # pylint: disable=old-style-class
allow_reuse_address = True
class BaseDualNetworkedServers(object):
"""Base class for a pair of IPv6 and IPv4 servers that tries to do everything
it's asked for both servers, but where failures in one server don't
affect the other.
If two servers are instantiated, they will serve on the same port.
"""
def __init__(self, ServerClass, server_address, *remaining_args, **kwargs):
port = server_address[1]
self.threads = []
self.servers = []
# Must try True first.
# Ubuntu, for example, will fail to bind to IPv4 if we've already bound
# to IPv6. But that's ok, since it will accept IPv4 connections on the IPv6
# socket. On the other hand, FreeBSD will successfully bind to IPv4 on the
# same port, which means that server will accept the IPv4 connections.
# If Python is compiled without IPv6, we'll error out but (probably) successfully
# create the IPv4 server.
for ip_version in [True, False]:
try:
kwargs["ipv6"] = ip_version
new_address = (server_address[0],) + (port,) + server_address[2:]
new_args = (new_address,) + remaining_args
server = ServerClass(*new_args, **kwargs) # pylint: disable=star-args
except socket.error:
logger.debug("Failed to bind to %s:%s using %s", new_address[0],
new_address[1], "IPv6" if ip_version else "IPv4")
else:
self.servers.append(server)
# If two servers are set up and port 0 was passed in, ensure we always
# bind to the same port for both servers.
port = server.socket.getsockname()[1]
if len(self.servers) == 0:
raise socket.error("Could not bind to IPv4 or IPv6.")
def serve_forever(self):
"""Wraps socketserver.TCPServer.serve_forever"""
for server in self.servers:
thread = threading.Thread(
# pylint: disable=no-member
target=server.serve_forever)
thread.start()
self.threads.append(thread)
def getsocknames(self):
"""Wraps socketserver.TCPServer.socket.getsockname"""
return [server.socket.getsockname() for server in self.servers]
def shutdown_and_server_close(self):
"""Wraps socketserver.TCPServer.shutdown, socketserver.TCPServer.server_close, and
threading.Thread.join"""
for server in self.servers:
server.shutdown()
server.server_close()
for thread in self.threads:
thread.join()
self.threads = []
class TLSSNI01Server(TLSServer, ACMEServerMixin):
"""TLSSNI01 Server."""
def __init__(self, server_address, certs):
def __init__(self, server_address, certs, ipv6=False):
TLSServer.__init__(
self, server_address, BaseRequestHandlerWithLogging, certs=certs)
self, server_address, BaseRequestHandlerWithLogging, certs=certs, ipv6=ipv6)
class TLSSNI01DualNetworkedServers(BaseDualNetworkedServers):
"""TLSSNI01Server Wrapper. Tries everything for both. Failures for one don't
affect the other."""
def __init__(self, *args, **kwargs):
BaseDualNetworkedServers.__init__(self, TLSSNI01Server, *args, **kwargs)
class BaseRequestHandlerWithLogging(socketserver.BaseRequestHandler):
@ -70,13 +146,33 @@ class BaseRequestHandlerWithLogging(socketserver.BaseRequestHandler):
socketserver.BaseRequestHandler.handle(self)
class HTTP01Server(BaseHTTPServer.HTTPServer, ACMEServerMixin):
class HTTPServer(BaseHTTPServer.HTTPServer):
"""Generic HTTP Server."""
def __init__(self, *args, **kwargs):
self.ipv6 = kwargs.pop("ipv6", False)
if self.ipv6:
self.address_family = socket.AF_INET6
else:
self.address_family = socket.AF_INET
BaseHTTPServer.HTTPServer.__init__(self, *args, **kwargs)
class HTTP01Server(HTTPServer, ACMEServerMixin):
"""HTTP01 Server."""
def __init__(self, server_address, resources):
BaseHTTPServer.HTTPServer.__init__(
def __init__(self, server_address, resources, ipv6=False):
HTTPServer.__init__(
self, server_address, HTTP01RequestHandler.partial_init(
simple_http_resources=resources))
simple_http_resources=resources), ipv6=ipv6)
class HTTP01DualNetworkedServers(BaseDualNetworkedServers):
"""HTTP01Server Wrapper. Tries everything for both. Failures for one don't
affect the other."""
def __init__(self, *args, **kwargs):
BaseDualNetworkedServers.__init__(self, HTTP01Server, *args, **kwargs)
class HTTP01RequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):

View file

@ -1,6 +1,7 @@
"""Tests for acme.standalone."""
import os
import shutil
import socket
import threading
import tempfile
import time
@ -9,6 +10,7 @@ import unittest
from six.moves import http_client # pylint: disable=import-error
from six.moves import socketserver # type: ignore # pylint: disable=import-error
import mock
import requests
from acme import challenges
@ -29,6 +31,13 @@ class TLSServerTest(unittest.TestCase):
('', 0), socketserver.BaseRequestHandler, bind_and_activate=True)
server.server_close() # pylint: disable=no-member
def test_ipv6(self):
if socket.has_ipv6:
from acme.standalone import TLSServer
server = TLSServer(
('', 0), socketserver.BaseRequestHandler, bind_and_activate=True, ipv6=True)
server.server_close() # pylint: disable=no-member
class TLSSNI01ServerTest(unittest.TestCase):
"""Test for acme.standalone.TLSSNI01Server."""
@ -112,6 +121,136 @@ class HTTP01ServerTest(unittest.TestCase):
self.assertFalse(self._test_http01(add=False))
class BaseDualNetworkedServersTest(unittest.TestCase):
"""Test for acme.standalone.BaseDualNetworkedServers."""
_multiprocess_can_split_ = True
class SingleProtocolServer(socketserver.TCPServer):
"""Server that only serves on a single protocol. FreeBSD has this behavior for AF_INET6."""
def __init__(self, *args, **kwargs):
ipv6 = kwargs.pop("ipv6", False)
if ipv6:
self.address_family = socket.AF_INET6
kwargs["bind_and_activate"] = False
else:
self.address_family = socket.AF_INET
socketserver.TCPServer.__init__(self, *args, **kwargs)
if ipv6:
# pylint: disable=no-member
self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
try:
self.server_bind()
self.server_activate()
except:
self.server_close()
raise
@mock.patch("socket.socket.bind")
def test_fail_to_bind(self, mock_bind):
mock_bind.side_effect = socket.error
from acme.standalone import BaseDualNetworkedServers
self.assertRaises(socket.error, BaseDualNetworkedServers,
BaseDualNetworkedServersTest.SingleProtocolServer,
("", 0),
socketserver.BaseRequestHandler)
def test_ports_equal(self):
from acme.standalone import BaseDualNetworkedServers
servers = BaseDualNetworkedServers(
BaseDualNetworkedServersTest.SingleProtocolServer,
("", 0),
socketserver.BaseRequestHandler)
socknames = servers.getsocknames()
prev_port = None
# assert ports are equal
for sockname in socknames:
port = sockname[1]
if prev_port:
self.assertEqual(prev_port, port)
prev_port = port
class TLSSNI01DualNetworkedServersTest(unittest.TestCase):
"""Test for acme.standalone.TLSSNI01DualNetworkedServers."""
_multiprocess_can_split_ = True
def setUp(self):
self.certs = {b'localhost': (
test_util.load_pyopenssl_private_key('rsa2048_key.pem'),
test_util.load_cert('rsa2048_cert.pem'),
)}
from acme.standalone import TLSSNI01DualNetworkedServers
self.servers = TLSSNI01DualNetworkedServers(("", 0), certs=self.certs)
self.servers.serve_forever()
def tearDown(self):
self.servers.shutdown_and_server_close()
def test_connect(self):
socknames = self.servers.getsocknames()
# connect to all addresses
for sockname in socknames:
host, port = sockname[:2]
cert = crypto_util.probe_sni(
b'localhost', host=host, port=port, timeout=1)
self.assertEqual(jose.ComparableX509(cert),
jose.ComparableX509(self.certs[b'localhost'][1]))
class HTTP01DualNetworkedServersTest(unittest.TestCase):
"""Tests for acme.standalone.HTTP01DualNetworkedServers."""
_multiprocess_can_split_ = True
def setUp(self):
self.account_key = jose.JWK.load(
test_util.load_vector('rsa1024_key.pem'))
self.resources = set()
from acme.standalone import HTTP01DualNetworkedServers
self.servers = HTTP01DualNetworkedServers(('', 0), resources=self.resources)
# pylint: disable=no-member
self.port = self.servers.getsocknames()[0][1]
self.servers.serve_forever()
def tearDown(self):
self.servers.shutdown_and_server_close()
def test_index(self):
response = requests.get(
'http://localhost:{0}'.format(self.port), verify=False)
self.assertEqual(
response.text, 'ACME client standalone challenge solver')
self.assertTrue(response.ok)
def test_404(self):
response = requests.get(
'http://localhost:{0}/foo'.format(self.port), verify=False)
self.assertEqual(response.status_code, http_client.NOT_FOUND)
def _test_http01(self, add):
chall = challenges.HTTP01(token=(b'x' * 16))
response, validation = chall.response_and_validation(self.account_key)
from acme.standalone import HTTP01RequestHandler
resource = HTTP01RequestHandler.HTTP01Resource(
chall=chall, response=response, validation=validation)
if add:
self.resources.add(resource)
return resource.response.simple_verify(
resource.chall, 'localhost', self.account_key.public_key(),
port=self.port)
def test_http01_found(self):
self.assertTrue(self._test_http01(add=True))
def test_http01_not_found(self):
self.assertFalse(self._test_http01(add=False))
class TestSimpleTLSSNI01Server(unittest.TestCase):
"""Tests for acme.standalone.simple_tls_sni_01_server."""

View file

@ -3,7 +3,6 @@ import argparse
import collections
import logging
import socket
import threading
import OpenSSL
import six
@ -33,8 +32,6 @@ class ServerManager(object):
will serve the same URLs!
"""
_Instance = collections.namedtuple("_Instance", "server thread")
def __init__(self, certs, http_01_resources):
self._instances = {}
self.certs = certs
@ -51,34 +48,32 @@ class ServerManager(object):
either `acme.challenge.HTTP01` or `acme.challenges.TLSSNI01`.
:param str listenaddr: (optional) The address to listen on. Defaults to all addrs.
:returns: Server instance.
:returns: DualNetworkedServers instance.
:rtype: ACMEServerMixin
"""
assert challenge_type in (challenges.TLSSNI01, challenges.HTTP01)
if port in self._instances:
return self._instances[port].server
return self._instances[port]
address = (listenaddr, port)
try:
if challenge_type is challenges.TLSSNI01:
server = acme_standalone.TLSSNI01Server(address, self.certs)
servers = acme_standalone.TLSSNI01DualNetworkedServers(address, self.certs)
else: # challenges.HTTP01
server = acme_standalone.HTTP01Server(
servers = acme_standalone.HTTP01DualNetworkedServers(
address, self.http_01_resources)
except socket.error as error:
raise errors.StandaloneBindError(error, port)
thread = threading.Thread(
# pylint: disable=no-member
target=server.serve_forever)
thread.start()
servers.serve_forever()
# if port == 0, then random free port on OS is taken
# pylint: disable=no-member
real_port = server.socket.getsockname()[1]
self._instances[real_port] = self._Instance(server, thread)
return server
# both servers, if they exist, have the same port
real_port = servers.getsocknames()[0][1]
self._instances[real_port] = servers
return servers
def stop(self, port):
"""Stop ACME server running on the specified ``port``.
@ -87,13 +82,12 @@ class ServerManager(object):
"""
instance = self._instances[port]
logger.debug("Stopping server at %s:%d...",
*instance.server.socket.getsockname()[:2])
instance.server.shutdown()
for sockname in instance.getsocknames():
logger.debug("Stopping server at %s:%d...",
*sockname[:2])
# Not calling server_close causes problems when renewing multiple
# certs with `certbot renew` using TLSSNI01 and PyOpenSSL 0.13
instance.server.server_close()
instance.thread.join()
instance.shutdown_and_server_close()
del self._instances[port]
def running(self):
@ -102,12 +96,11 @@ class ServerManager(object):
Once the server is stopped using `stop`, it will not be
returned.
:returns: Mapping from ``port`` to ``server``.
:returns: Mapping from ``port`` to ``servers``.
:rtype: tuple
"""
return dict((port, instance.server) for port, instance
in six.iteritems(self._instances))
return self._instances.copy()
SUPPORTED_CHALLENGES = [challenges.TLSSNI01, challenges.HTTP01]
@ -236,38 +229,38 @@ class Authenticator(common.Plugin):
def _perform_single(self, achall):
if isinstance(achall.chall, challenges.HTTP01):
server, response = self._perform_http_01(achall)
servers, response = self._perform_http_01(achall)
else: # tls-sni-01
server, response = self._perform_tls_sni_01(achall)
self.served[server].add(achall)
servers, response = self._perform_tls_sni_01(achall)
self.served[servers].add(achall)
return response
def _perform_http_01(self, achall):
port = self.config.http01_port
addr = self.config.http01_address
server = self.servers.run(port, challenges.HTTP01, listenaddr=addr)
servers = self.servers.run(port, challenges.HTTP01, listenaddr=addr)
response, validation = achall.response_and_validation()
resource = acme_standalone.HTTP01RequestHandler.HTTP01Resource(
chall=achall.chall, response=response, validation=validation)
self.http_01_resources.add(resource)
return server, response
return servers, response
def _perform_tls_sni_01(self, achall):
port = self.config.tls_sni_01_port
addr = self.config.tls_sni_01_address
server = self.servers.run(port, challenges.TLSSNI01, listenaddr=addr)
servers = self.servers.run(port, challenges.TLSSNI01, listenaddr=addr)
response, (cert, _) = achall.response_and_validation(cert_key=self.key)
self.certs[response.z_domain] = (self.key, cert)
return server, response
return servers, response
def cleanup(self, achalls): # pylint: disable=missing-docstring
# reduce self.served and close servers if none challenges are served
for server, server_achalls in self.served.items():
# reduce self.served and close servers if no challenges are served
for unused_servers, server_achalls in self.served.items():
for achall in achalls:
if achall in server_achalls:
server_achalls.remove(achall)
for port, server in six.iteritems(self.servers.running()):
if not self.served[server]:
for port, servers in six.iteritems(self.servers.running()):
if not self.served[servers]:
self.servers.stop(port)

View file

@ -32,7 +32,7 @@ class ServerManagerTest(unittest.TestCase):
def _test_run_stop(self, challenge_type):
server = self.mgr.run(port=0, challenge_type=challenge_type)
port = server.socket.getsockname()[1] # pylint: disable=no-member
port = server.getsocknames()[0][1] # pylint: disable=no-member
self.assertEqual(self.mgr.running(), {port: server})
self.mgr.stop(port=port)
self.assertEqual(self.mgr.running(), {})
@ -45,7 +45,7 @@ class ServerManagerTest(unittest.TestCase):
def test_run_idempotent(self):
server = self.mgr.run(port=0, challenge_type=challenges.HTTP01)
port = server.socket.getsockname()[1] # pylint: disable=no-member
port = server.getsocknames()[0][1] # pylint: disable=no-member
server2 = self.mgr.run(port=port, challenge_type=challenges.HTTP01)
self.assertEqual(self.mgr.running(), {port: server})
self.assertTrue(server is server2)
@ -53,9 +53,14 @@ class ServerManagerTest(unittest.TestCase):
self.assertEqual(self.mgr.running(), {})
def test_run_bind_error(self):
some_server = socket.socket()
some_server = socket.socket(socket.AF_INET6)
some_server.bind(("", 0))
port = some_server.getsockname()[1]
maybe_another_server = socket.socket()
try:
maybe_another_server.bind(("", port))
except socket.error:
pass
self.assertRaises(
errors.StandaloneBindError, self.mgr.run, port,
challenge_type=challenges.HTTP01)