Remove serve_forever2/shutdown2 (reduces probability of #1085).

I'm not even sure why `serve_forever2` and `shutdown2` were introduced
in the first place... It probably follows from my misconception about
the SocketServer module. After having studied the module again, I come
to the conclusion that we can get rid of my crap, simultanously
reducing probability of #1085 (hopefully down to 0)!

`server_forever` is used throughout tests instead of `handle_request`,
because `shutdown`, following docs, "must be called while
serve_forever() is running in another thread, or it will deadlock",
and our `probe_sni` HTTP request is already enough to kill single
`handle_request`.

We don't need to use any busy waiting block or `sleep` between serve
and shutdown; studying CPython source code leads to the conclusion
that the following construction is non-blocking:

```python
import threading, SocketServer
s = SocketServer.TCPServer(("", 0), None)
t = threading.Thread(target=s.shutdown)
t.start()
s.serve_forever()  # returns immediately
t.join()  # returns immediately
```
This commit is contained in:
Jakub Warmuz 2015-10-29 20:46:43 +00:00
parent 6124571f34
commit 4cc0610679
No known key found for this signature in database
GPG key ID: 2A7BAD3A489B52EA
3 changed files with 10 additions and 89 deletions

View file

@ -4,7 +4,6 @@ import collections
import functools
import logging
import os
import socket
import sys
import six
@ -50,37 +49,11 @@ class ACMEServerMixin: # pylint: disable=old-style-class
server_version = "ACME client standalone challenge solver"
allow_reuse_address = True
def __init__(self):
self._stopped = False
def serve_forever2(self):
"""Serve forever, until other thread calls `shutdown2`."""
logger.debug("Starting server at %s:%d...",
*self.socket.getsockname()[:2])
while not self._stopped:
self.handle_request()
def shutdown2(self):
"""Shutdown server loop from `serve_forever2`."""
self._stopped = True
# dummy request to terminate last server_forever2.handle_request()
sock = socket.socket()
try:
sock.connect(self.socket.getsockname())
except socket.error:
pass # thread is probably already finished
finally:
sock.close()
self.server_close()
class DVSNIServer(TLSServer, ACMEServerMixin):
"""DVSNI Server."""
def __init__(self, server_address, certs):
ACMEServerMixin.__init__(self)
TLSServer.__init__(
self, server_address, socketserver.BaseRequestHandler, certs=certs)
@ -89,7 +62,6 @@ class SimpleHTTPServer(BaseHTTPServer.HTTPServer, ACMEServerMixin):
"""SimpleHTTP Server."""
def __init__(self, server_address, resources):
ACMEServerMixin.__init__(self)
BaseHTTPServer.HTTPServer.__init__(
self, server_address, SimpleHTTPRequestHandler.partial_init(
simple_http_resources=resources))

View file

@ -1,7 +1,6 @@
"""Tests for acme.standalone."""
import os
import shutil
import socket
import threading
import tempfile
import time
@ -29,54 +28,6 @@ class TLSServerTest(unittest.TestCase):
server.server_close() # pylint: disable=no-member
class ACMEServerMixinTest(unittest.TestCase):
"""Tests for acme.standalone.ACMEServerMixin."""
def setUp(self):
from acme.standalone import ACMEServerMixin
class _MockHandler(socketserver.BaseRequestHandler):
# pylint: disable=missing-docstring,no-member,no-init
def handle(self):
self.request.sendall(b"DONE")
class _MockServer(socketserver.TCPServer, ACMEServerMixin):
def __init__(self, *args, **kwargs):
socketserver.TCPServer.__init__(self, *args, **kwargs)
ACMEServerMixin.__init__(self)
self.server = _MockServer(("", 0), _MockHandler)
def _busy_wait(self): # pragma: no cover
# This function is used to avoid race conditions in tests, but
# not all of the functionality is always used, hence "no
# cover"
while True:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
# pylint: disable=no-member
sock.connect(self.server.socket.getsockname())
except socket.error:
pass
else:
sock.recv(4) # wait until handle_request is actually called
break
finally:
sock.close()
time.sleep(1)
def test_serve_shutdown(self):
thread = threading.Thread(target=self.server.serve_forever2)
thread.start()
self._busy_wait()
self.server.shutdown2()
def test_shutdown2_not_running(self):
self.server.shutdown2()
self.server.shutdown2()
class DVSNIServerTest(unittest.TestCase):
"""Test for acme.standalone.DVSNIServer."""
@ -89,20 +40,16 @@ class DVSNIServerTest(unittest.TestCase):
from acme.standalone import DVSNIServer
self.server = DVSNIServer(("", 0), certs=self.certs)
# pylint: disable=no-member
self.thread = threading.Thread(target=self.server.handle_request)
self.thread = threading.Thread(target=self.server.serve_forever)
self.thread.start()
def tearDown(self):
self.server.shutdown2()
self.server.shutdown() # pylint: disable=no-member
self.thread.join()
def test_init(self):
# pylint: disable=protected-access
self.assertFalse(self.server._stopped)
def test_dvsni(self):
def test_it(self):
host, port = self.server.socket.getsockname()[:2]
cert = crypto_util.probe_sni(b'localhost', host=host, port=port)
cert = crypto_util.probe_sni(b'localhost', host=host, port=port, timeout=1)
self.assertEqual(jose.ComparableX509(cert),
jose.ComparableX509(self.certs[b'localhost'][1]))
@ -120,11 +67,11 @@ class SimpleHTTPServerTest(unittest.TestCase):
# pylint: disable=no-member
self.port = self.server.socket.getsockname()[1]
self.thread = threading.Thread(target=self.server.handle_request)
self.thread = threading.Thread(target=self.server.serve_forever)
self.thread.start()
def tearDown(self):
self.server.shutdown2()
self.server.shutdown() # pylint: disable=no-member
self.thread.join()
def test_index(self):

View file

@ -72,7 +72,9 @@ class ServerManager(object):
except socket.error as error:
raise errors.StandaloneBindError(error, port)
thread = threading.Thread(target=server.serve_forever2)
thread = threading.Thread(
# pylint: disable=no-member
target=server.serve_forever)
thread.start()
# if port == 0, then random free port on OS is taken
@ -90,7 +92,7 @@ class ServerManager(object):
instance = self._instances[port]
logger.debug("Stopping server at %s:%d...",
*instance.server.socket.getsockname()[:2])
instance.server.shutdown2()
instance.server.shutdown()
instance.thread.join()
del self._instances[port]