From ac9a28827402346702fc97442ffaf3e4e3618fe4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niels=20M=C3=B6ller?= Date: Wed, 20 Oct 2021 15:25:09 +0200 Subject: [PATCH] Disable SSLAdapter methods Listen and Accept Only affects turn server. Refactored to wrap sockets with SSLAdapter after Accept, using the SSLAdapterFactory to hold needed configuration. Bug: webrtc:13065 Change-Id: I5df65aad5728d8d40d95b22db6398a573ec7a36f Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/235823 Reviewed-by: Harald Alvestrand Commit-Queue: Niels Moller Cr-Commit-Position: refs/heads/main@{#35258} --- p2p/base/test_turn_server.h | 21 +++++++++++-------- p2p/base/turn_server.cc | 19 +++++++++++++----- p2p/base/turn_server.h | 13 ++++++++++-- rtc_base/openssl_adapter.cc | 40 +++++++++++++++++++++---------------- rtc_base/openssl_adapter.h | 17 +++++++++++++++- rtc_base/ssl_adapter.cc | 4 ++-- rtc_base/ssl_adapter.h | 18 ++++++++++++++++- 7 files changed, 96 insertions(+), 36 deletions(-) diff --git a/p2p/base/test_turn_server.h b/p2p/base/test_turn_server.h index e1deb5901e..6cad13525f 100644 --- a/p2p/base/test_turn_server.h +++ b/p2p/base/test_turn_server.h @@ -11,7 +11,9 @@ #ifndef P2P_BASE_TEST_TURN_SERVER_H_ #define P2P_BASE_TEST_TURN_SERVER_H_ +#include #include +#include #include #include "api/sequence_checker.h" @@ -104,21 +106,24 @@ class TestTurnServer : public TurnAuthInterface { // new connections. rtc::Socket* socket = thread_->socketserver()->CreateSocket(AF_INET, SOCK_STREAM); + socket->Bind(int_addr); + socket->Listen(5); if (proto == cricket::PROTO_TLS) { // For TLS, wrap the TCP socket with an SSL adapter. The adapter must // be configured with a self-signed certificate for testing. // Additionally, the client will not present a valid certificate, so we // must not fail when checking the peer's identity. - rtc::SSLAdapter* adapter = rtc::SSLAdapter::Create(socket); - adapter->SetRole(rtc::SSL_SERVER); - adapter->SetIdentity( + std::unique_ptr ssl_adapter_factory = + rtc::SSLAdapterFactory::Create(); + ssl_adapter_factory->SetRole(rtc::SSL_SERVER); + ssl_adapter_factory->SetIdentity( rtc::SSLIdentity::Create(common_name, rtc::KeyParams())); - adapter->SetIgnoreBadCert(ignore_bad_cert); - socket = adapter; + ssl_adapter_factory->SetIgnoreBadCert(ignore_bad_cert); + server_.AddInternalServerSocket(socket, proto, + std::move(ssl_adapter_factory)); + } else { + server_.AddInternalServerSocket(socket, proto); } - socket->Bind(int_addr); - socket->Listen(5); - server_.AddInternalServerSocket(socket, proto); } else { RTC_NOTREACHED() << "Unknown protocol type: " << proto; } diff --git a/p2p/base/turn_server.cc b/p2p/base/turn_server.cc index fd9cd16138..5685e20876 100644 --- a/p2p/base/turn_server.cc +++ b/p2p/base/turn_server.cc @@ -152,12 +152,15 @@ void TurnServer::AddInternalSocket(rtc::AsyncPacketSocket* socket, socket->SignalReadPacket.connect(this, &TurnServer::OnInternalPacket); } -void TurnServer::AddInternalServerSocket(rtc::Socket* socket, - ProtocolType proto) { +void TurnServer::AddInternalServerSocket( + rtc::Socket* socket, + ProtocolType proto, + std::unique_ptr ssl_adapter_factory) { RTC_DCHECK_RUN_ON(thread_); + RTC_DCHECK(server_listen_sockets_.end() == server_listen_sockets_.find(socket)); - server_listen_sockets_[socket] = proto; + server_listen_sockets_[socket] = {proto, std::move(ssl_adapter_factory)}; socket->SignalReadEvent.connect(this, &TurnServer::OnNewInternalConnection); } @@ -181,13 +184,19 @@ void TurnServer::AcceptConnection(rtc::Socket* server_socket) { rtc::SocketAddress accept_addr; rtc::Socket* accepted_socket = server_socket->Accept(&accept_addr); if (accepted_socket != NULL) { - ProtocolType proto = server_listen_sockets_[server_socket]; + const ServerSocketInfo& info = server_listen_sockets_[server_socket]; + if (info.ssl_adapter_factory) { + rtc::SSLAdapter* ssl_adapter = + info.ssl_adapter_factory->CreateAdapter(accepted_socket); + ssl_adapter->StartSSL(""); + accepted_socket = ssl_adapter; + } cricket::AsyncStunTCPSocket* tcp_socket = new cricket::AsyncStunTCPSocket(accepted_socket); tcp_socket->SignalClose.connect(this, &TurnServer::OnInternalSocketClose); // Finally add the socket so it can start communicating with the client. - AddInternalSocket(tcp_socket, proto); + AddInternalSocket(tcp_socket, info.proto); } } diff --git a/p2p/base/turn_server.h b/p2p/base/turn_server.h index 7942c09af9..481b081172 100644 --- a/p2p/base/turn_server.h +++ b/p2p/base/turn_server.h @@ -23,6 +23,7 @@ #include "p2p/base/port_interface.h" #include "rtc_base/async_packet_socket.h" #include "rtc_base/socket_address.h" +#include "rtc_base/ssl_adapter.h" #include "rtc_base/third_party/sigslot/sigslot.h" #include "rtc_base/thread.h" @@ -237,7 +238,10 @@ class TurnServer : public sigslot::has_slots<> { // Starts listening for the connections on this socket. When someone tries // to connect, the connection will be accepted and a new internal socket // will be added. - void AddInternalServerSocket(rtc::Socket* socket, ProtocolType proto); + void AddInternalServerSocket( + rtc::Socket* socket, + ProtocolType proto, + std::unique_ptr ssl_adapter_factory = nullptr); // Specifies the factory to use for creating external sockets. void SetExternalSocketFactory(rtc::PacketSocketFactory* factory, const rtc::SocketAddress& address); @@ -320,7 +324,12 @@ class TurnServer : public sigslot::has_slots<> { RTC_RUN_ON(thread_); typedef std::map InternalSocketMap; - typedef std::map ServerSocketMap; + struct ServerSocketInfo { + ProtocolType proto; + // If non-null, used to wrap accepted sockets. + std::unique_ptr ssl_adapter_factory; + }; + typedef std::map ServerSocketMap; rtc::Thread* const thread_; const std::string nonce_key_; diff --git a/rtc_base/openssl_adapter.cc b/rtc_base/openssl_adapter.cc index 7489bc992d..bc10e619eb 100644 --- a/rtc_base/openssl_adapter.cc +++ b/rtc_base/openssl_adapter.cc @@ -250,21 +250,6 @@ void OpenSSLAdapter::SetRole(SSLRole role) { role_ = role; } -Socket* OpenSSLAdapter::Accept(SocketAddress* paddr) { - RTC_DCHECK(role_ == SSL_SERVER); - Socket* socket = SSLAdapter::Accept(paddr); - if (!socket) { - return nullptr; - } - - SSLAdapter* adapter = SSLAdapter::Create(socket); - adapter->SetIdentity(identity_->Clone()); - adapter->SetRole(rtc::SSL_SERVER); - adapter->SetIgnoreBadCert(ignore_bad_cert_); - adapter->StartSSL(""); - return adapter; -} - int OpenSSLAdapter::StartSSL(const char* hostname) { if (state_ != SSL_NONE) return -1; @@ -1038,6 +1023,21 @@ void OpenSSLAdapterFactory::SetCertVerifier( ssl_cert_verifier_ = ssl_cert_verifier; } +void OpenSSLAdapterFactory::SetIdentity(std::unique_ptr identity) { + RTC_DCHECK(!ssl_session_cache_); + identity_ = std::move(identity); +} + +void OpenSSLAdapterFactory::SetRole(SSLRole role) { + RTC_DCHECK(!ssl_session_cache_); + ssl_role_ = role; +} + +void OpenSSLAdapterFactory::SetIgnoreBadCert(bool ignore) { + RTC_DCHECK(!ssl_session_cache_); + ignore_bad_cert_ = ignore; +} + OpenSSLAdapter* OpenSSLAdapterFactory::CreateAdapter(Socket* socket) { if (ssl_session_cache_ == nullptr) { SSL_CTX* ssl_ctx = OpenSSLAdapter::CreateContext(ssl_mode_, true); @@ -1049,8 +1049,14 @@ OpenSSLAdapter* OpenSSLAdapterFactory::CreateAdapter(Socket* socket) { std::make_unique(ssl_mode_, ssl_ctx); SSL_CTX_free(ssl_ctx); } - return new OpenSSLAdapter(socket, ssl_session_cache_.get(), - ssl_cert_verifier_); + OpenSSLAdapter* ssl_adapter = + new OpenSSLAdapter(socket, ssl_session_cache_.get(), ssl_cert_verifier_); + ssl_adapter->SetRole(ssl_role_); + ssl_adapter->SetIgnoreBadCert(ignore_bad_cert_); + if (identity_) { + ssl_adapter->SetIdentity(identity_->Clone()); + } + return ssl_adapter; } OpenSSLAdapter::EarlyExitCatcher::EarlyExitCatcher(OpenSSLAdapter& adapter_ptr) diff --git a/rtc_base/openssl_adapter.h b/rtc_base/openssl_adapter.h index 266ed35421..7e1f87b8ab 100644 --- a/rtc_base/openssl_adapter.h +++ b/rtc_base/openssl_adapter.h @@ -60,7 +60,6 @@ class OpenSSLAdapter final : public SSLAdapter, void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) override; void SetIdentity(std::unique_ptr identity) override; void SetRole(SSLRole role) override; - Socket* Accept(SocketAddress* paddr) override; int StartSSL(const char* hostname) override; int Send(const void* pv, size_t cb) override; int SendTo(const void* pv, size_t cb, const SocketAddress& addr) override; @@ -191,10 +190,21 @@ class OpenSSLAdapterFactory : public SSLAdapterFactory { // the first adapter is created with the factory. If it is called after it // will DCHECK. void SetMode(SSLMode mode) override; + // Set a custom certificate verifier to be passed down to each instance // created with this factory. This should only ever be set before the first // call to the factory and cannot be changed after the fact. void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) override; + + void SetIdentity(std::unique_ptr identity) override; + + // Choose whether the socket acts as a server socket or client socket. + void SetRole(SSLRole role) override; + + // Methods that control server certificate verification, used in unit tests. + // Do not call these methods in production code. + void SetIgnoreBadCert(bool ignore) override; + // Constructs a new socket using the shared OpenSSLSessionCache. This means // existing SSLSessions already in the cache will be reused instead of // re-created for improved performance. @@ -203,6 +213,11 @@ class OpenSSLAdapterFactory : public SSLAdapterFactory { private: // Holds the SSLMode (DTLS,TLS) that will be used to set the session cache. SSLMode ssl_mode_ = SSL_MODE_TLS; + SSLRole ssl_role_ = SSL_CLIENT; + bool ignore_bad_cert_ = false; + + std::unique_ptr identity_; + // Holds a cache of existing SSL Sessions. std::unique_ptr ssl_session_cache_; // Provides an optional custom callback for verifying SSL certificates, this diff --git a/rtc_base/ssl_adapter.cc b/rtc_base/ssl_adapter.cc index c9b54c4dc9..ff936a79fb 100644 --- a/rtc_base/ssl_adapter.cc +++ b/rtc_base/ssl_adapter.cc @@ -16,8 +16,8 @@ namespace rtc { -SSLAdapterFactory* SSLAdapterFactory::Create() { - return new OpenSSLAdapterFactory(); +std::unique_ptr SSLAdapterFactory::Create() { + return std::make_unique(); } SSLAdapter* SSLAdapter::Create(Socket* socket) { diff --git a/rtc_base/ssl_adapter.h b/rtc_base/ssl_adapter.h index 1f0616bffc..8f98141651 100644 --- a/rtc_base/ssl_adapter.h +++ b/rtc_base/ssl_adapter.h @@ -39,10 +39,21 @@ class SSLAdapterFactory { // Specify a custom certificate verifier for SSL. virtual void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) = 0; + // Set the certificate this socket will present to incoming clients. + // Takes ownership of `identity`. + virtual void SetIdentity(std::unique_ptr identity) = 0; + + // Choose whether the socket acts as a server socket or client socket. + virtual void SetRole(SSLRole role) = 0; + + // Methods that control server certificate verification, used in unit tests. + // Do not call these methods in production code. + virtual void SetIgnoreBadCert(bool ignore) = 0; + // Creates a new SSL adapter, but from a shared context. virtual SSLAdapter* CreateAdapter(Socket* socket) = 0; - static SSLAdapterFactory* Create(); + static std::unique_ptr Create(); }; // Class that abstracts a client-to-server SSL session. It can be created @@ -91,6 +102,11 @@ class SSLAdapter : public AsyncSocketAdapter { // and deletes `socket`. Otherwise, the returned SSLAdapter takes ownership // of `socket`. static SSLAdapter* Create(Socket* socket); + + private: + // Not supported. + int Listen(int backlog) override { RTC_CHECK(false); } + Socket* Accept(SocketAddress* paddr) override { RTC_CHECK(false); } }; ///////////////////////////////////////////////////////////////////////////////