From c08de0f4b7c8593fdecde2e5b1bbd9adecffb89b Mon Sep 17 00:00:00 2001 From: Zhi Huang Date: Mon, 11 Dec 2017 00:20:23 -0800 Subject: [PATCH] Allow the TransportController to create RTP level transports. Add methods to TransportController so that it can create RTP level transports (RtpTransport, SrtpTransport and DtlsSrtpTransport.). The RTP level transports are ref-counted since they could be shared by multiple BaseChannels and TransportController manages the life time of the transports. Bug: webrtc:7013 Change-Id: Ifd31062426e513d95473e257c9c9ff84a8c537fd Reviewed-on: https://webrtc-review.googlesource.com/5080 Commit-Queue: Zhi Huang Reviewed-by: Peter Thatcher Reviewed-by: Steve Anton Cr-Commit-Position: refs/heads/master@{#21196} --- pc/transportcontroller.cc | 135 +++++++++++++++++++++++++++++ pc/transportcontroller.h | 37 ++++++++ pc/transportcontroller_unittest.cc | 105 ++++++++++++++++++++++ 3 files changed, 277 insertions(+) diff --git a/pc/transportcontroller.cc b/pc/transportcontroller.cc index ce49af3594..a9b030cfe9 100644 --- a/pc/transportcontroller.cc +++ b/pc/transportcontroller.cc @@ -12,10 +12,12 @@ #include #include +#include #include "p2p/base/port.h" #include "rtc_base/bind.h" #include "rtc_base/checks.h" +#include "rtc_base/ptr_util.h" #include "rtc_base/thread.h" using webrtc::SdpType; @@ -329,6 +331,133 @@ void TransportController::DestroyDtlsTransport_n( UpdateAggregateStates_n(); } +webrtc::SrtpTransport* TransportController::CreateSdesTransport( + const std::string& transport_name, + bool rtcp_mux_enabled) { + if (!network_thread_->IsCurrent()) { + return network_thread_->Invoke(RTC_FROM_HERE, [&] { + return CreateSdesTransport(transport_name, rtcp_mux_enabled); + }); + } + + auto existing_rtp_transport = FindRtpTransport(transport_name); + + if (existing_rtp_transport) { + // For SRTP transport wrapper, the |srtp_transport| is expected to be + // non-null and |dtls_srtp_transport| is expected to be a nullptr. + if (!existing_rtp_transport->srtp_transport || + existing_rtp_transport->dtls_srtp_transport) { + RTC_LOG(LS_ERROR) + << "Failed to create an RTP transport for SDES using name: " + << transport_name << " because the type doesn't match."; + return nullptr; + } + existing_rtp_transport->AddRef(); + return existing_rtp_transport->srtp_transport; + } + + auto new_srtp_transport = + rtc::MakeUnique(rtcp_mux_enabled); + + // The SDES should use an IceTransport rather than a DtlsTransport. We call + // |CreateDtlsTransport_n| here because the DtlsTransport will downgrade to an + // wrapper over IceTransport if we don't set the certificates and it will just + // forward the packets and signals without using DTLS. The support of SDES + // will be removed once all the downstream application stop using it. + new_srtp_transport->SetRtpPacketTransport(CreateDtlsTransport_n( + transport_name, cricket::ICE_CANDIDATE_COMPONENT_RTP)); + if (!rtcp_mux_enabled) { + new_srtp_transport->SetRtcpPacketTransport(CreateDtlsTransport_n( + transport_name, cricket::ICE_CANDIDATE_COMPONENT_RTCP)); + } + +#if defined(ENABLE_EXTERNAL_AUTH) + new_srtp_transport->EnableExternalAuth(); +#endif + + auto new_rtp_transport_wrapper = new RefCountedRtpTransport(); + new_rtp_transport_wrapper->srtp_transport = new_srtp_transport.get(); + new_rtp_transport_wrapper->rtp_transport = std::move(new_srtp_transport); + new_rtp_transport_wrapper->AddRef(); + rtp_transports_[transport_name] = new_rtp_transport_wrapper; + return rtp_transports_[transport_name]->srtp_transport; +} + +webrtc::DtlsSrtpTransport* TransportController::CreateDtlsSrtpTransport( + const std::string& transport_name, + bool rtcp_mux_enabled) { + if (!network_thread_->IsCurrent()) { + return network_thread_->Invoke( + RTC_FROM_HERE, [&] { + return CreateDtlsSrtpTransport(transport_name, rtcp_mux_enabled); + }); + } + auto existing_rtp_transport = FindRtpTransport(transport_name); + + if (existing_rtp_transport) { + // For DTLS-SRTP transport wrapper, the |dtls_srtp_transport| is expected to + // be non-null and |srtp_transport| is expected to be a nullptr. + if (existing_rtp_transport->srtp_transport || + !existing_rtp_transport->dtls_srtp_transport) { + RTC_LOG(LS_ERROR) + << "Failed to create an RTP transport for DTLS-SRTP using name: " + << transport_name << " because the type doesn't match."; + return nullptr; + } + existing_rtp_transport->AddRef(); + return existing_rtp_transport->dtls_srtp_transport; + } + + auto new_srtp_transport = + rtc::MakeUnique(rtcp_mux_enabled); + +#if defined(ENABLE_EXTERNAL_AUTH) + new_srtp_transport->EnableExternalAuth(); +#endif + + auto new_dtls_srtp_transport = + rtc::MakeUnique(std::move(new_srtp_transport)); + + auto rtp_dtls_transport = CreateDtlsTransport_n( + transport_name, cricket::ICE_CANDIDATE_COMPONENT_RTP); + auto rtcp_dtls_transport = + rtcp_mux_enabled + ? nullptr + : CreateDtlsTransport_n(transport_name, + cricket::ICE_CANDIDATE_COMPONENT_RTCP); + + new_dtls_srtp_transport->SetDtlsTransports(rtp_dtls_transport, + rtcp_dtls_transport); + + auto new_rtp_transport_wrapper = new RefCountedRtpTransport(); + new_rtp_transport_wrapper->dtls_srtp_transport = + new_dtls_srtp_transport.get(); + new_rtp_transport_wrapper->rtp_transport = std::move(new_dtls_srtp_transport); + new_rtp_transport_wrapper->AddRef(); + rtp_transports_[transport_name] = new_rtp_transport_wrapper; + return rtp_transports_[transport_name]->dtls_srtp_transport; +} + +void TransportController::DestroyTransport(const std::string& transport_name) { + if (!network_thread_->IsCurrent()) { + network_thread_->Invoke(RTC_FROM_HERE, + [&] { DestroyTransport(transport_name); }); + return; + } + + auto existing_rtp_transport = FindRtpTransport(transport_name); + if (!existing_rtp_transport) { + RTC_LOG(LS_WARNING) << "Attempting to delete " << transport_name + << " transport , which doesn't exist."; + return; + } + if (existing_rtp_transport->Release() == + rtc::RefCountReleaseStatus::kDroppedLastRef) { + rtp_transports_.erase(transport_name); + } + return; +} + std::vector TransportController::transport_names_for_testing() { std::vector ret; for (const auto& kv : transports_) { @@ -404,6 +533,12 @@ void TransportController::OnMessage(rtc::Message* pmsg) { } } +const TransportController::RefCountedRtpTransport* +TransportController::FindRtpTransport(const std::string& transport_name) { + auto it = rtp_transports_.find(transport_name); + return it == rtp_transports_.end() ? nullptr : it->second; +} + std::vector::iterator TransportController::GetChannelIterator_n(const std::string& transport_name, int component) { diff --git a/pc/transportcontroller.h b/pc/transportcontroller.h index 51f870e6a8..bdde17258b 100644 --- a/pc/transportcontroller.h +++ b/pc/transportcontroller.h @@ -20,6 +20,9 @@ #include "p2p/base/dtlstransport.h" #include "p2p/base/jseptransport.h" #include "p2p/base/p2ptransportchannel.h" +#include "pc/dtlssrtptransport.h" +#include "pc/rtptransport.h" +#include "pc/srtptransport.h" #include "rtc_base/asyncinvoker.h" #include "rtc_base/constructormagic.h" #include "rtc_base/refcountedobject.h" @@ -127,6 +130,20 @@ class TransportController : public sigslot::has_slots<>, virtual void DestroyDtlsTransport_n(const std::string& transport_name, int component); + // Create an SrtpTransport/DtlsSrtpTransport if it doesn't exist. + // Otherwise, increments a reference count and returns the existing one. + // These methods are not currently used but the plan is to transition + // PeerConnection and BaseChannel to use them instead of CreateDtlsTransport. + webrtc::SrtpTransport* CreateSdesTransport(const std::string& transport_name, + bool rtcp_mux_enabled); + webrtc::DtlsSrtpTransport* CreateDtlsSrtpTransport( + const std::string& transport_name, + bool rtcp_mux_enabled); + + // Destroy an RTP level transport which can be an RtpTransport, an + // SrtpTransport or a DtlsSrtpTransport. + void DestroyTransport(const std::string& transport_name); + // TODO(deadbeef): Remove all for_testing methods! const rtc::scoped_refptr& certificate_for_testing() const { @@ -180,6 +197,24 @@ class TransportController : public sigslot::has_slots<>, class ChannelPair; typedef rtc::RefCountedObject RefCountedChannel; + // Wrapper for RtpTransport that keeps a reference count. + // When using SDES, |srtp_transport| is non-null, |dtls_srtp_transport| is + // null and |rtp_transport.get()| == |srtp_transport|, + // When using DTLS-SRTP, |dtls_srtp_transport| is non-null, |srtp_transport| + // is null and |rtp_transport.get()| == |dtls_srtp_transport|, + // When using unencrypted RTP, only |rtp_transport| is non-null. + struct RtpTransportWrapper { + // |rtp_transport| is always non-null. + std::unique_ptr rtp_transport; + webrtc::SrtpTransport* srtp_transport = nullptr; + webrtc::DtlsSrtpTransport* dtls_srtp_transport = nullptr; + }; + + typedef rtc::RefCountedObject RefCountedRtpTransport; + + const RefCountedRtpTransport* FindRtpTransport( + const std::string& transport_name); + // Helper functions to get a channel or transport, or iterator to it (in case // it needs to be erased). std::vector::iterator GetChannelIterator_n( @@ -251,6 +286,8 @@ class TransportController : public sigslot::has_slots<>, std::map> transports_; std::vector channels_; + std::map rtp_transports_; + // Aggregate state for TransportChannelImpls. IceConnectionState connection_state_ = kIceConnectionConnecting; bool receiving_ = false; diff --git a/pc/transportcontroller_unittest.cc b/pc/transportcontroller_unittest.cc index f6564c5c60..838e277a03 100644 --- a/pc/transportcontroller_unittest.cc +++ b/pc/transportcontroller_unittest.cc @@ -22,6 +22,7 @@ #include "rtc_base/helpers.h" #include "rtc_base/sslidentity.h" #include "rtc_base/thread.h" +#include "test/gtest.h" using webrtc::SdpType; @@ -32,6 +33,7 @@ static const char kIceUfrag2[] = "TESTICEUFRAG0002"; static const char kIcePwd2[] = "TESTICEPWD00000000000002"; static const char kIceUfrag3[] = "TESTICEUFRAG0003"; static const char kIcePwd3[] = "TESTICEPWD00000000000003"; +static const bool kRtcpMuxEnabled = true; namespace cricket { @@ -908,4 +910,107 @@ TEST_F(TransportControllerTest, NeedsIceRestart) { EXPECT_TRUE(transport_controller_->NeedsIceRestart("video")); } +enum class RTPTransportType { kSdes, kDtlsSrtp }; +std::ostream& operator<<(std::ostream& out, RTPTransportType value) { + switch (value) { + case RTPTransportType::kSdes: + return out << "SDES"; + case RTPTransportType::kDtlsSrtp: + return out << "DTLS-SRTP"; + } + return out; +} + +// Tests the TransportController can correctly create and destroy the RTP +// transports. +class TransportControllerRTPTransportTest + : public TransportControllerTest, + public ::testing::WithParamInterface { + protected: + // Helper function used to create an RTP layer transport. + webrtc::RtpTransportInternal* CreateRtpTransport( + const std::string& transport_name) { + RTPTransportType type = GetParam(); + switch (type) { + case RTPTransportType::kSdes: + return transport_controller_->CreateSdesTransport(transport_name, + kRtcpMuxEnabled); + case RTPTransportType::kDtlsSrtp: + return transport_controller_->CreateDtlsSrtpTransport(transport_name, + kRtcpMuxEnabled); + } + return nullptr; + } +}; + +// Tests that creating transports with the same name will cause the +// second call to re-use the transport created in the first call. +TEST_P(TransportControllerRTPTransportTest, CreateTransportsWithReference) { + const std::string transport_name = "transport"; + webrtc::RtpTransportInternal* transport1 = CreateRtpTransport(transport_name); + webrtc::RtpTransportInternal* transport2 = CreateRtpTransport(transport_name); + EXPECT_NE(nullptr, transport1); + EXPECT_NE(nullptr, transport2); + // The TransportController is expected to return the existing one when using + // the same transport name. + EXPECT_EQ(transport1, transport2); + transport_controller_->DestroyTransport(transport_name); + transport_controller_->DestroyTransport(transport_name); +} + +// Tests that creating different type of RTP transports with same name is not +// allowed. +TEST_P(TransportControllerRTPTransportTest, + CreateDifferentTypeOfTransportsWithSameName) { + const std::string transport_name = "transport"; + webrtc::RtpTransportInternal* transport1 = CreateRtpTransport(transport_name); + EXPECT_NE(nullptr, transport1); + RTPTransportType type = GetParam(); + switch (type) { + case RTPTransportType::kSdes: + EXPECT_EQ(nullptr, transport_controller_->CreateDtlsSrtpTransport( + transport_name, kRtcpMuxEnabled)); + break; + case RTPTransportType::kDtlsSrtp: + EXPECT_EQ(nullptr, transport_controller_->CreateSdesTransport( + transport_name, kRtcpMuxEnabled)); + break; + default: + ASSERT_TRUE(false); + } + transport_controller_->DestroyTransport(transport_name); +} + +// Tests the RTP transport is not actually destroyed if references still exist. +TEST_P(TransportControllerRTPTransportTest, DestroyTransportWithReference) { + const std::string transport_name = "transport"; + webrtc::RtpTransportInternal* transport1 = CreateRtpTransport(transport_name); + webrtc::RtpTransportInternal* transport2 = CreateRtpTransport(transport_name); + EXPECT_NE(nullptr, transport1); + EXPECT_NE(nullptr, transport2); + transport_controller_->DestroyTransport(transport_name); + EXPECT_NE(nullptr, transport1->rtp_packet_transport()); + EXPECT_EQ(nullptr, transport1->rtcp_packet_transport()); + transport_controller_->DestroyTransport(transport_name); +} + +// Tests the RTP is actually destroyed if there is no reference to it. +TEST_P(TransportControllerRTPTransportTest, DestroyTransportWithNoReference) { + const std::string transport_name = "transport"; + webrtc::RtpTransportInternal* transport1 = CreateRtpTransport(transport_name); + webrtc::RtpTransportInternal* transport2 = CreateRtpTransport(transport_name); + EXPECT_NE(nullptr, transport1); + EXPECT_NE(nullptr, transport2); + transport_controller_->DestroyTransport(transport_name); + transport_controller_->DestroyTransport(transport_name); +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + EXPECT_DEATH(transport1->IsWritable(false), /*error_message=*/""); +#endif +} + +INSTANTIATE_TEST_CASE_P(TransportControllerTest, + TransportControllerRTPTransportTest, + ::testing::Values(RTPTransportType::kSdes, + RTPTransportType::kDtlsSrtp)); + } // namespace cricket