diff --git a/webrtc/rtc_base/openssladapter.cc b/webrtc/rtc_base/openssladapter.cc index eec80216da..11473ac225 100644 --- a/webrtc/rtc_base/openssladapter.cc +++ b/webrtc/rtc_base/openssladapter.cc @@ -274,8 +274,10 @@ bool OpenSSLAdapter::CleanupSSL() { return true; } -OpenSSLAdapter::OpenSSLAdapter(AsyncSocket* socket) +OpenSSLAdapter::OpenSSLAdapter(AsyncSocket* socket, + OpenSSLAdapterFactory* factory) : SSLAdapter(socket), + factory_(factory), state_(SSL_NONE), ssl_read_needs_write_(false), ssl_write_needs_read_(false), @@ -283,20 +285,29 @@ OpenSSLAdapter::OpenSSLAdapter(AsyncSocket* socket) ssl_(nullptr), ssl_ctx_(nullptr), ssl_mode_(SSL_MODE_TLS), - custom_verification_succeeded_(false) {} + custom_verification_succeeded_(false) { + // If a factory is used, take a reference on the factory's SSL_CTX. + // Otherwise, we'll create our own later. + // Either way, we'll release our reference via SSL_CTX_free() in Cleanup(). + if (factory_) { + ssl_ctx_ = factory_->ssl_ctx(); + RTC_DCHECK(ssl_ctx_); + // Note: if using OpenSSL, requires version 1.1.0 or later. + SSL_CTX_up_ref(ssl_ctx_); + } +} OpenSSLAdapter::~OpenSSLAdapter() { Cleanup(); } -void -OpenSSLAdapter::SetMode(SSLMode mode) { +void OpenSSLAdapter::SetMode(SSLMode mode) { + RTC_DCHECK(!ssl_ctx_); RTC_DCHECK(state_ == SSL_NONE); ssl_mode_ = mode; } -int -OpenSSLAdapter::StartSSL(const char* hostname, bool restartable) { +int OpenSSLAdapter::StartSSL(const char* hostname, bool restartable) { if (state_ != SSL_NONE) return -1; @@ -317,18 +328,20 @@ OpenSSLAdapter::StartSSL(const char* hostname, bool restartable) { return 0; } -int -OpenSSLAdapter::BeginSSL() { - LOG(LS_INFO) << "BeginSSL: " << ssl_host_name_; +int OpenSSLAdapter::BeginSSL() { + LOG(LS_INFO) << "OpenSSLAdapter::BeginSSL: " << ssl_host_name_; RTC_DCHECK(state_ == SSL_CONNECTING); int err = 0; BIO* bio = nullptr; - // First set up the context - if (!ssl_ctx_) - ssl_ctx_ = SetupSSLContext(); - + // First set up the context. We should either have a factory, with its own + // pre-existing context, or be running standalone, in which case we will + // need to create one, and specify |false| to disable session caching. + if (!factory_) { + RTC_DCHECK(!ssl_ctx_); + ssl_ctx_ = CreateContext(ssl_mode_, false); + } if (!ssl_ctx_) { err = -1; goto ssl_error; @@ -348,7 +361,6 @@ OpenSSLAdapter::BeginSSL() { SSL_set_app_data(ssl_, this); - SSL_set_bio(ssl_, bio, bio); // SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER allows different buffers to be passed // into SSL_write when a record could only be partially transmitted (and thus // requires another call to SSL_write to finish transmission). This allows us @@ -360,9 +372,24 @@ OpenSSLAdapter::BeginSSL() { SSL_set_mode(ssl_, SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); - // Enable SNI. + // Enable SNI, if a hostname is supplied. if (!ssl_host_name_.empty()) { SSL_set_tlsext_host_name(ssl_, ssl_host_name_.c_str()); + + // Enable session caching, if configured and a hostname is supplied. + if (factory_) { + SSL_SESSION* cached = factory_->LookupSession(ssl_host_name_); + if (cached) { + if (SSL_set_session(ssl_, cached) == 0) { + LOG(LS_WARNING) << "Failed to apply SSL session from cache"; + err = -1; + goto ssl_error; + } + + LOG(LS_INFO) << "Attempting to resume SSL session to " + << ssl_host_name_; + } + } } // Set a couple common TLS extensions; even though we don't use them yet. @@ -370,10 +397,12 @@ OpenSSLAdapter::BeginSSL() { SSL_enable_ocsp_stapling(ssl_); SSL_enable_signed_cert_timestamps(ssl_); - // the SSL object owns the bio now + // Now that the initial config is done, transfer ownership of |bio| to the + // SSL object. If ContinueSSL() fails, the bio will be freed in Cleanup(). + SSL_set_bio(ssl_, bio, bio); bio = nullptr; - // Do the connect + // Do the connect. err = ContinueSSL(); if (err != 0) goto ssl_error; @@ -388,8 +417,7 @@ ssl_error: return err; } -int -OpenSSLAdapter::ContinueSSL() { +int OpenSSLAdapter::ContinueSSL() { RTC_DCHECK(state_ == SSL_CONNECTING); // Clear the DTLS timer @@ -441,8 +469,7 @@ OpenSSLAdapter::ContinueSSL() { return 0; } -void -OpenSSLAdapter::Error(const char* context, int err, bool signal) { +void OpenSSLAdapter::Error(const char* context, int err, bool signal) { LOG(LS_WARNING) << "OpenSSLAdapter::Error(" << context << ", " << err << ")"; state_ = SSL_ERROR; @@ -451,9 +478,8 @@ OpenSSLAdapter::Error(const char* context, int err, bool signal) { AsyncSocketAdapter::OnCloseEvent(this, err); } -void -OpenSSLAdapter::Cleanup() { - LOG(LS_INFO) << "Cleanup"; +void OpenSSLAdapter::Cleanup() { + LOG(LS_INFO) << "OpenSSLAdapter::Cleanup"; state_ = SSL_NONE; ssl_read_needs_write_ = false; @@ -519,8 +545,7 @@ int OpenSSLAdapter::DoSslWrite(const void* pv, size_t cb, int* error) { // AsyncSocket Implementation // -int -OpenSSLAdapter::Send(const void* pv, size_t cb) { +int OpenSSLAdapter::Send(const void* pv, size_t cb) { //LOG(LS_INFO) << "OpenSSLAdapter::Send(" << cb << ")"; switch (state_) { @@ -589,8 +614,9 @@ OpenSSLAdapter::Send(const void* pv, size_t cb) { return ret; } -int -OpenSSLAdapter::SendTo(const void* pv, size_t cb, const SocketAddress& addr) { +int OpenSSLAdapter::SendTo(const void* pv, + size_t cb, + const SocketAddress& addr) { if (socket_->GetState() == Socket::CS_CONNECTED && addr == socket_->GetRemoteAddress()) { return Send(pv, cb); @@ -677,15 +703,13 @@ int OpenSSLAdapter::RecvFrom(void* pv, return SOCKET_ERROR; } -int -OpenSSLAdapter::Close() { +int OpenSSLAdapter::Close() { Cleanup(); state_ = restartable_ ? SSL_WAIT : SSL_NONE; return AsyncSocketAdapter::Close(); } -Socket::ConnState -OpenSSLAdapter::GetState() const { +Socket::ConnState OpenSSLAdapter::GetState() const { //if (signal_close_) // return CS_CONNECTED; ConnState state = socket_->GetState(); @@ -695,8 +719,11 @@ OpenSSLAdapter::GetState() const { return state; } -void -OpenSSLAdapter::OnMessage(Message* msg) { +bool OpenSSLAdapter::IsResumedSession() { + return (ssl_ && SSL_session_reused(ssl_) == 1); +} + +void OpenSSLAdapter::OnMessage(Message* msg) { if (MSG_TIMEOUT == msg->message_id) { LOG(LS_INFO) << "DTLS timeout expired"; DTLSv1_handle_timeout(ssl_); @@ -704,8 +731,7 @@ OpenSSLAdapter::OnMessage(Message* msg) { } } -void -OpenSSLAdapter::OnConnectEvent(AsyncSocket* socket) { +void OpenSSLAdapter::OnConnectEvent(AsyncSocket* socket) { LOG(LS_INFO) << "OpenSSLAdapter::OnConnectEvent"; if (state_ != SSL_WAIT) { RTC_DCHECK(state_ == SSL_NONE); @@ -719,8 +745,7 @@ OpenSSLAdapter::OnConnectEvent(AsyncSocket* socket) { } } -void -OpenSSLAdapter::OnReadEvent(AsyncSocket* socket) { +void OpenSSLAdapter::OnReadEvent(AsyncSocket* socket) { //LOG(LS_INFO) << "OpenSSLAdapter::OnReadEvent"; if (state_ == SSL_NONE) { @@ -749,8 +774,7 @@ OpenSSLAdapter::OnReadEvent(AsyncSocket* socket) { AsyncSocketAdapter::OnReadEvent(socket); } -void -OpenSSLAdapter::OnWriteEvent(AsyncSocket* socket) { +void OpenSSLAdapter::OnWriteEvent(AsyncSocket* socket) { //LOG(LS_INFO) << "OpenSSLAdapter::OnWriteEvent"; if (state_ == SSL_NONE) { @@ -790,8 +814,7 @@ OpenSSLAdapter::OnWriteEvent(AsyncSocket* socket) { AsyncSocketAdapter::OnWriteEvent(socket); } -void -OpenSSLAdapter::OnCloseEvent(AsyncSocket* socket, int err) { +void OpenSSLAdapter::OnCloseEvent(AsyncSocket* socket, int err) { LOG(LS_INFO) << "OpenSSLAdapter::OnCloseEvent(" << err << ")"; AsyncSocketAdapter::OnCloseEvent(socket, err); } @@ -891,8 +914,7 @@ bool OpenSSLAdapter::SSLPostConnectionCheck(SSL* ssl, const char* host) { // We only use this for tracing and so it is only needed in debug mode -void -OpenSSLAdapter::SSLInfoCallback(const SSL* s, int where, int ret) { +void OpenSSLAdapter::SSLInfoCallback(const SSL* s, int where, int ret) { const char* str = "undefined"; int w = where & ~SSL_ST_MASK; if (w & SSL_ST_CONNECT) { @@ -918,8 +940,7 @@ OpenSSLAdapter::SSLInfoCallback(const SSL* s, int where, int ret) { #endif -int -OpenSSLAdapter::SSLVerifyCallback(int ok, X509_STORE_CTX* store) { +int OpenSSLAdapter::SSLVerifyCallback(int ok, X509_STORE_CTX* store) { #if !defined(NDEBUG) if (!ok) { char data[256]; @@ -964,6 +985,15 @@ OpenSSLAdapter::SSLVerifyCallback(int ok, X509_STORE_CTX* store) { return ok; } +int OpenSSLAdapter::NewSSLSessionCallback(SSL* ssl, SSL_SESSION* session) { + OpenSSLAdapter* stream = + reinterpret_cast(SSL_get_app_data(ssl)); + RTC_DCHECK(stream->factory_); + LOG(LS_INFO) << "Caching SSL session for " << stream->ssl_host_name_; + stream->factory_->AddSession(stream->ssl_host_name_, session); + return 1; // We've taken ownership of the session; OpenSSL shouldn't free it. +} + bool OpenSSLAdapter::ConfigureTrustedRootCertificates(SSL_CTX* ctx) { // Add the root cert that we care about to the SSL context int count_of_added_certs = 0; @@ -985,18 +1015,17 @@ bool OpenSSLAdapter::ConfigureTrustedRootCertificates(SSL_CTX* ctx) { return count_of_added_certs > 0; } -SSL_CTX* -OpenSSLAdapter::SetupSSLContext() { +SSL_CTX* OpenSSLAdapter::CreateContext(SSLMode mode, bool enable_cache) { // Use (D)TLS 1.2. // Note: BoringSSL supports a range of versions by setting max/min version // (Default V1.0 to V1.2). However (D)TLSv1_2_client_method functions used // below in OpenSSL only support V1.2. SSL_CTX* ctx = nullptr; #ifdef OPENSSL_IS_BORINGSSL - ctx = SSL_CTX_new(ssl_mode_ == SSL_MODE_DTLS ? DTLS_method() : TLS_method()); + ctx = SSL_CTX_new(mode == SSL_MODE_DTLS ? DTLS_method() : TLS_method()); #else - ctx = SSL_CTX_new(ssl_mode_ == SSL_MODE_DTLS ? DTLSv1_2_client_method() - : TLSv1_2_client_method()); + ctx = SSL_CTX_new(mode == SSL_MODE_DTLS ? DTLSv1_2_client_method() + : TLSv1_2_client_method()); #endif // OPENSSL_IS_BORINGSSL if (ctx == nullptr) { unsigned long error = ERR_get_error(); // NOLINT: type used by OpenSSL. @@ -1023,11 +1052,59 @@ OpenSSLAdapter::SetupSSLContext() { SSL_CTX_set_cipher_list( ctx, "ALL:!SHA256:!SHA384:!aPSK:!ECDSA+SHA1:!ADH:!LOW:!EXP:!MD5"); - if (ssl_mode_ == SSL_MODE_DTLS) { + if (mode == SSL_MODE_DTLS) { SSL_CTX_set_read_ahead(ctx, 1); } + if (enable_cache) { + SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_CLIENT); + SSL_CTX_sess_set_new_cb(ctx, &OpenSSLAdapter::NewSSLSessionCallback); + } + return ctx; } +////////////////////////////////////////////////////////////////////// +// OpenSSLAdapterFactory +////////////////////////////////////////////////////////////////////// + +OpenSSLAdapterFactory::OpenSSLAdapterFactory() + : ssl_mode_(SSL_MODE_TLS), ssl_ctx_(nullptr) {} + +OpenSSLAdapterFactory::~OpenSSLAdapterFactory() { + for (auto it : sessions_) { + SSL_SESSION_free(it.second); + } + SSL_CTX_free(ssl_ctx_); +} + +void OpenSSLAdapterFactory::SetMode(SSLMode mode) { + RTC_DCHECK(!ssl_ctx_); + ssl_mode_ = mode; +} + +OpenSSLAdapter* OpenSSLAdapterFactory::CreateAdapter(AsyncSocket* socket) { + if (!ssl_ctx_) { + bool enable_cache = true; + ssl_ctx_ = OpenSSLAdapter::CreateContext(ssl_mode_, enable_cache); + if (!ssl_ctx_) { + return nullptr; + } + } + + return new OpenSSLAdapter(socket, this); +} + +SSL_SESSION* OpenSSLAdapterFactory::LookupSession(const std::string& hostname) { + auto it = sessions_.find(hostname); + return (it != sessions_.end()) ? it->second : nullptr; +} + +void OpenSSLAdapterFactory::AddSession(const std::string& hostname, + SSL_SESSION* new_session) { + SSL_SESSION* old_session = LookupSession(hostname); + SSL_SESSION_free(old_session); + sessions_[hostname] = new_session; +} + } // namespace rtc diff --git a/webrtc/rtc_base/openssladapter.h b/webrtc/rtc_base/openssladapter.h index 251714545d..4b49efd25f 100644 --- a/webrtc/rtc_base/openssladapter.h +++ b/webrtc/rtc_base/openssladapter.h @@ -11,6 +11,7 @@ #ifndef WEBRTC_RTC_BASE_OPENSSLADAPTER_H_ #define WEBRTC_RTC_BASE_OPENSSLADAPTER_H_ +#include #include #include "webrtc/rtc_base/buffer.h" #include "webrtc/rtc_base/messagehandler.h" @@ -20,18 +21,20 @@ typedef struct ssl_st SSL; typedef struct ssl_ctx_st SSL_CTX; typedef struct x509_store_ctx_st X509_STORE_CTX; +typedef struct ssl_session_st SSL_SESSION; namespace rtc { -/////////////////////////////////////////////////////////////////////////////// +class OpenSSLAdapterFactory; class OpenSSLAdapter : public SSLAdapter, public MessageHandler { -public: + public: static bool InitializeSSL(VerificationCallback callback); static bool InitializeSSLThread(); static bool CleanupSSL(); - OpenSSLAdapter(AsyncSocket* socket); + explicit OpenSSLAdapter(AsyncSocket* socket, + OpenSSLAdapterFactory* factory = nullptr); ~OpenSSLAdapter() override; void SetMode(SSLMode mode) override; @@ -47,14 +50,23 @@ public: // Note that the socket returns ST_CONNECTING while SSL is being negotiated. ConnState GetState() const override; + bool IsResumedSession() override; -protected: - void OnConnectEvent(AsyncSocket* socket) override; - void OnReadEvent(AsyncSocket* socket) override; - void OnWriteEvent(AsyncSocket* socket) override; - void OnCloseEvent(AsyncSocket* socket, int err) override; + // Creates a new SSL_CTX object, configured for client-to-server usage + // with SSLMode |mode|, and if |enable_cache| is true, with support for + // storing successful sessions so that they can be later resumed. + // OpenSSLAdapterFactory will call this method to create its own internal + // SSL_CTX, and OpenSSLAdapter will also call this when used without a + // factory. + static SSL_CTX* CreateContext(SSLMode mode, bool enable_cache); -private: + protected: + void OnConnectEvent(AsyncSocket* socket) override; + void OnReadEvent(AsyncSocket* socket) override; + void OnWriteEvent(AsyncSocket* socket) override; + void OnCloseEvent(AsyncSocket* socket, int err) override; + + private: enum SSLState { SSL_NONE, SSL_WAIT, SSL_CONNECTING, SSL_CONNECTED, SSL_ERROR }; @@ -76,19 +88,29 @@ private: bool ignore_bad_cert); bool SSLPostConnectionCheck(SSL* ssl, const char* host); #if !defined(NDEBUG) - static void SSLInfoCallback(const SSL* s, int where, int ret); + // In debug builds, logs info about the state of the SSL connection. + static void SSLInfoCallback(const SSL* ssl, int where, int ret); #endif static int SSLVerifyCallback(int ok, X509_STORE_CTX* store); static VerificationCallback custom_verify_callback_; friend class OpenSSLStreamAdapter; // for custom_verify_callback_; + // If the SSL_CTX was created with |enable_cache| set to true, this callback + // will be called when a SSL session has been successfully established, + // to allow its SSL_SESSION* to be cached for later resumption. + static int NewSSLSessionCallback(SSL* ssl, SSL_SESSION* session); + static bool ConfigureTrustedRootCertificates(SSL_CTX* ctx); - SSL_CTX* SetupSSLContext(); + + // Parent object that maintains shared state. + // Can be null if state sharing is not needed. + OpenSSLAdapterFactory* factory_; SSLState state_; bool ssl_read_needs_write_; bool ssl_write_needs_read_; // If true, socket will retain SSL configuration after Close. + // TODO(juberti): Remove this unused flag. bool restartable_; // This buffer is used if SSL_write fails with SSL_ERROR_WANT_WRITE, which @@ -105,9 +127,34 @@ private: bool custom_verification_succeeded_; }; -///////////////////////////////////////////////////////////////////////////// +class OpenSSLAdapterFactory : public SSLAdapterFactory { + public: + OpenSSLAdapterFactory(); + ~OpenSSLAdapterFactory() override; -} // namespace rtc + void SetMode(SSLMode mode) override; + OpenSSLAdapter* CreateAdapter(AsyncSocket* socket) override; + static OpenSSLAdapterFactory* Create(); -#endif // WEBRTC_RTC_BASE_OPENSSLADAPTER_H_ + private: + SSL_CTX* ssl_ctx() { return ssl_ctx_; } + // Looks up a session by hostname. The returned SSL_SESSION is not up_refed. + SSL_SESSION* LookupSession(const std::string& hostname); + // Adds a session to the cache, and up_refs it. Any existing session with the + // same hostname is replaced. + void AddSession(const std::string& hostname, SSL_SESSION* session); + friend class OpenSSLAdapter; + + SSLMode ssl_mode_; + // Holds the shared SSL_CTX for all created adapters. + SSL_CTX* ssl_ctx_; + // Map of hostnames to SSL_SESSIONs; holds references to the SSL_SESSIONs, + // which are cleaned up when the factory is destroyed. + // TODO(juberti): Add LRU eviction to keep the cache from growing forever. + std::map sessions_; +}; + +} // namespace rtc + +#endif // WEBRTC_RTC_BASE_OPENSSLADAPTER_H_ diff --git a/webrtc/rtc_base/ssladapter.cc b/webrtc/rtc_base/ssladapter.cc index 07a13b5b23..f26ebdab4b 100644 --- a/webrtc/rtc_base/ssladapter.cc +++ b/webrtc/rtc_base/ssladapter.cc @@ -16,8 +16,11 @@ namespace rtc { -SSLAdapter* -SSLAdapter::Create(AsyncSocket* socket) { +SSLAdapterFactory* SSLAdapterFactory::Create() { + return new OpenSSLAdapterFactory(); +} + +SSLAdapter* SSLAdapter::Create(AsyncSocket* socket) { return new OpenSSLAdapter(socket); } diff --git a/webrtc/rtc_base/ssladapter.h b/webrtc/rtc_base/ssladapter.h index dccb6d186d..6b12035d2f 100644 --- a/webrtc/rtc_base/ssladapter.h +++ b/webrtc/rtc_base/ssladapter.h @@ -16,13 +16,37 @@ namespace rtc { -/////////////////////////////////////////////////////////////////////////////// +class SSLAdapter; +// Class for creating SSL adapters with shared state, e.g., a session cache, +// which allows clients to resume SSL sessions to previously-contacted hosts. +// Clients should create the factory using Create(), set up the factory as +// needed using SetMode, and then call CreateAdapter to create adapters when +// needed. +class SSLAdapterFactory { + public: + virtual ~SSLAdapterFactory() {} + // Specifies whether TLS or DTLS is to be used for the SSL adapters. + virtual void SetMode(SSLMode mode) = 0; + // Creates a new SSL adapter, but from a shared context. + virtual SSLAdapter* CreateAdapter(AsyncSocket* socket) = 0; + + static SSLAdapterFactory* Create(); +}; + +// Class that abstracts a client-to-server SSL session. It can be created +// standalone, via SSLAdapter::Create, or through a factory as described above, +// in which case it will share state with other SSLAdapters created from the +// same factory. +// After creation, call StartSSL to initiate the SSL handshake to the server. class SSLAdapter : public AsyncSocketAdapter { public: - explicit SSLAdapter(AsyncSocket* socket) - : AsyncSocketAdapter(socket), ignore_bad_cert_(false) { } + explicit SSLAdapter(AsyncSocket* socket) : AsyncSocketAdapter(socket) {} + // Methods that control server certificate verification, used in unit tests. + // Do not call these methods in production code. + // TODO(juberti): Remove the opportunistic encryption mechanism in + // BasicPacketSocketFactory that uses this function. bool ignore_bad_cert() const { return ignore_bad_cert_; } void set_ignore_bad_cert(bool ignore) { ignore_bad_cert_ = ignore; } @@ -32,7 +56,15 @@ class SSLAdapter : public AsyncSocketAdapter { // StartSSL returns 0 if successful. // If StartSSL is called while the socket is closed or connecting, the SSL // negotiation will begin as soon as the socket connects. - virtual int StartSSL(const char* hostname, bool restartable) = 0; + // TODO(juberti): Remove |restartable|. + virtual int StartSSL(const char* hostname, bool restartable = false) = 0; + + // When an SSLAdapterFactory is used, an SSLAdapter may be used to resume + // a previous SSL session, which results in an abbreviated handshake. + // This method, if called after SSL has been established for this adapter, + // indicates whether the current session is a resumption of a previous + // session. + virtual bool IsResumedSession() = 0; // Create the default SSL adapter for this platform. On failure, returns null // and deletes |socket|. Otherwise, the returned SSLAdapter takes ownership @@ -41,7 +73,7 @@ class SSLAdapter : public AsyncSocketAdapter { private: // If true, the server certificate need not match the configured hostname. - bool ignore_bad_cert_; + bool ignore_bad_cert_ = false; }; /////////////////////////////////////////////////////////////////////////////// @@ -58,8 +90,6 @@ bool InitializeSSLThread(); // Call to cleanup additional threads, and also the main thread. bool CleanupSSL(); -/////////////////////////////////////////////////////////////////////////////// - } // namespace rtc #endif // WEBRTC_RTC_BASE_SSLADAPTER_H_