From 737441f201d70da9ef08b254641e154a92673079 Mon Sep 17 00:00:00 2001 From: Gunnar Beutner Date: Wed, 23 Apr 2014 13:42:59 +0200 Subject: [PATCH] Fix deadlock in TlsStream::{Read,Write}. Refs #5467 --- components/agent/agentlistener.cpp | 4 +- components/cluster/clusterlistener.cpp | 4 +- lib/base/socket.cpp | 46 ++++++++++++++++++++++ lib/base/socket.h | 7 +++- lib/base/tlsstream.cpp | 53 +++++++++++++++++--------- lib/base/tlsstream.h | 6 +-- 6 files changed, 93 insertions(+), 27 deletions(-) diff --git a/components/agent/agentlistener.cpp b/components/agent/agentlistener.cpp index 0da7c9100..dfc99198c 100644 --- a/components/agent/agentlistener.cpp +++ b/components/agent/agentlistener.cpp @@ -144,13 +144,11 @@ void AgentListener::NewClientHandler(const Socket::Ptr& client, TlsRole role) { CONTEXT("Handling new agent client connection"); - NetworkStream::Ptr netStream = make_shared(client); - TlsStream::Ptr tlsStream; { ObjectLock olock(this); - tlsStream = make_shared(netStream, role, m_SSLContext); + tlsStream = make_shared(client, role, m_SSLContext); } tlsStream->Handshake(); diff --git a/components/cluster/clusterlistener.cpp b/components/cluster/clusterlistener.cpp index 807a48767..df77c0464 100644 --- a/components/cluster/clusterlistener.cpp +++ b/components/cluster/clusterlistener.cpp @@ -526,9 +526,7 @@ void ClusterListener::NewClientHandler(const Socket::Ptr& client, TlsRole role) { CONTEXT("Handling new cluster client connection"); - NetworkStream::Ptr netStream = make_shared(client); - - TlsStream::Ptr tlsStream = make_shared(netStream, role, m_SSLContext); + TlsStream::Ptr tlsStream = make_shared(client, role, m_SSLContext); tlsStream->Handshake(); shared_ptr cert = tlsStream->GetPeerCertificate(); diff --git a/lib/base/socket.cpp b/lib/base/socket.cpp index c713b9937..1502df542 100644 --- a/lib/base/socket.cpp +++ b/lib/base/socket.cpp @@ -28,6 +28,10 @@ #include #include +#ifndef _WIN32 +# include +#endif /* _WIN32 */ + using namespace icinga; /** @@ -284,3 +288,45 @@ Socket::Ptr Socket::Accept(void) return make_shared(fd); } + +void Socket::Poll(bool read, bool write) +{ +#ifdef _WIN32 + fd_set readfds, writefds, exceptfds; + + FD_ZERO(&readfds); + if (read) + FD_SET(GetFD(), &readfds); + + FD_ZERO(&writefds); + if (write) + FD_SET(GetFD(), &writefds); + + FD_ZERO(&exceptfds); + FD_SET(GetFD(), &exceptfds); + + if (select(GetFD() + 1, &readfds, &writefds, &exceptfds, NULL) < 0) + BOOST_THROW_EXCEPTION(socket_error() + << boost::errinfo_api_function("select") + << errinfo_win32_error(WSAGetLastError())); +#else /* _WIN32 */ + pollfd pfd; + pfd.fd = GetFD(); + pfd.events = (read ? POLLIN : 0) | (write ? POLLOUT : 0); + pfd.revents = 0; + + if (poll(&pfd, 1, -1) < 0) + BOOST_THROW_EXCEPTION(socket_error() + << boost::errinfo_api_function("poll") + << boost::errinfo_errno(errno)); +#endif /* _WIN32 */ +} + +void Socket::MakeNonBlocking(void) +{ +#ifdef _WIN32 + Utility::SetNonBlockingSocket(GetFD()); +#else /* _WIN32 */ + Utility::SetNonBlocking(GetFD()); +#endif /* _WIN32 */ +} \ No newline at end of file diff --git a/lib/base/socket.h b/lib/base/socket.h index 85e757af3..e8c008ac0 100644 --- a/lib/base/socket.h +++ b/lib/base/socket.h @@ -43,6 +43,8 @@ public: Socket(SOCKET fd); ~Socket(void); + SOCKET GetFD(void) const; + void Close(void); String GetClientAddress(void); @@ -54,9 +56,12 @@ public: void Listen(void); Socket::Ptr Accept(void); + void Poll(bool read, bool write); + + void MakeNonBlocking(void); + protected: void SetFD(SOCKET fd); - SOCKET GetFD(void) const; int GetError(void) const; diff --git a/lib/base/tlsstream.cpp b/lib/base/tlsstream.cpp index 0f76f6275..934be2cd0 100644 --- a/lib/base/tlsstream.cpp +++ b/lib/base/tlsstream.cpp @@ -18,7 +18,6 @@ ******************************************************************************/ #include "base/tlsstream.h" -#include "base/stream_bio.h" #include "base/objectlock.h" #include "base/debug.h" #include "base/utility.h" @@ -37,8 +36,8 @@ bool I2_EXPORT TlsStream::m_SSLIndexInitialized = false; * @param role The role of the client. * @param sslContext The SSL context for the client. */ -TlsStream::TlsStream(const Stream::Ptr& innerStream, TlsRole role, shared_ptr sslContext) - : m_InnerStream(innerStream), m_Role(role) +TlsStream::TlsStream(const Socket::Ptr& socket, TlsRole role, shared_ptr sslContext) + : m_Socket(socket), m_Role(role) { m_SSL = shared_ptr(SSL_new(sslContext.get()), SSL_free); @@ -57,7 +56,10 @@ TlsStream::TlsStream(const Stream::Ptr& innerStream, TlsRole role, shared_ptrMakeNonBlocking(); + + m_BIO = BIO_new_socket(socket->GetFD(), 0); + BIO_set_nbio(m_BIO, 1); SSL_set_bio(m_SSL.get(), m_BIO, m_BIO); if (m_Role == TlsRoleServer) @@ -92,19 +94,28 @@ void TlsStream::Handshake(void) int rc; - ObjectLock olock(this); + for (;;) { + int rc; + + { + ObjectLock olock(this); + rc = SSL_do_handshake(m_SSL.get()); + } + + if (rc > 0) + break; - while ((rc = SSL_do_handshake(m_SSL.get())) <= 0) { switch (SSL_get_error(m_SSL.get(), rc)) { case SSL_ERROR_WANT_READ: + m_Socket->Poll(true, false); continue; case SSL_ERROR_WANT_WRITE: + m_Socket->Poll(false, true); continue; case SSL_ERROR_ZERO_RETURN: Close(); return; default: - I2Stream_check_exception(m_BIO); BOOST_THROW_EXCEPTION(openssl_error() << boost::errinfo_api_function("SSL_do_handshake") << errinfo_openssl_error(ERR_get_error())); @@ -121,22 +132,26 @@ size_t TlsStream::Read(void *buffer, size_t count) size_t left = count; - ObjectLock olock(this); - while (left > 0) { - int rc = SSL_read(m_SSL.get(), ((char *)buffer) + (count - left), left); + int rc; + + { + ObjectLock olock(this); + rc = SSL_read(m_SSL.get(), ((char *)buffer) + (count - left), left); + } if (rc <= 0) { switch (SSL_get_error(m_SSL.get(), rc)) { case SSL_ERROR_WANT_READ: + m_Socket->Poll(true, false); continue; case SSL_ERROR_WANT_WRITE: + m_Socket->Poll(false, true); continue; case SSL_ERROR_ZERO_RETURN: Close(); return count - left; default: - I2Stream_check_exception(m_BIO); BOOST_THROW_EXCEPTION(openssl_error() << boost::errinfo_api_function("SSL_read") << errinfo_openssl_error(ERR_get_error())); @@ -155,22 +170,26 @@ void TlsStream::Write(const void *buffer, size_t count) size_t left = count; - ObjectLock olock(this); - while (left > 0) { - int rc = SSL_write(m_SSL.get(), ((const char *)buffer) + (count - left), left); + int rc; + + { + ObjectLock olock(this); + rc = SSL_write(m_SSL.get(), ((const char *)buffer) + (count - left), left); + } if (rc <= 0) { switch (SSL_get_error(m_SSL.get(), rc)) { case SSL_ERROR_WANT_READ: + m_Socket->Poll(true, false); continue; case SSL_ERROR_WANT_WRITE: + m_Socket->Poll(false, true); continue; case SSL_ERROR_ZERO_RETURN: Close(); return; default: - I2Stream_check_exception(m_BIO); BOOST_THROW_EXCEPTION(openssl_error() << boost::errinfo_api_function("SSL_write") << errinfo_openssl_error(ERR_get_error())); @@ -186,10 +205,10 @@ void TlsStream::Write(const void *buffer, size_t count) */ void TlsStream::Close(void) { - m_InnerStream->Close(); + m_Socket->Close(); } bool TlsStream::IsEof(void) const { - return m_InnerStream->IsEof(); + return BIO_eof(m_BIO); } diff --git a/lib/base/tlsstream.h b/lib/base/tlsstream.h index 0d60bacef..55685c9b9 100644 --- a/lib/base/tlsstream.h +++ b/lib/base/tlsstream.h @@ -21,7 +21,7 @@ #define TLSSTREAM_H #include "base/i2-base.h" -#include "base/stream.h" +#include "base/socket.h" #include "base/fifo.h" #include "base/tlsutility.h" @@ -44,7 +44,7 @@ class I2_BASE_API TlsStream : public Stream public: DECLARE_PTR_TYPEDEFS(TlsStream); - TlsStream(const Stream::Ptr& innerStream, TlsRole role, shared_ptr sslContext); + TlsStream(const Socket::Ptr& socket, TlsRole role, shared_ptr sslContext); shared_ptr GetClientCertificate(void) const; shared_ptr GetPeerCertificate(void) const; @@ -62,7 +62,7 @@ private: shared_ptr m_SSL; BIO *m_BIO; - Stream::Ptr m_InnerStream; + Socket::Ptr m_Socket; TlsRole m_Role; static int m_SSLIndex;