diff --git a/media/base/rtputils.cc b/media/base/rtputils.cc index 7cf2c1ba7e..d0ba1cf72b 100644 --- a/media/base/rtputils.cc +++ b/media/base/rtputils.cc @@ -275,16 +275,6 @@ bool IsRtpPacket(const void* data, size_t len) { return (static_cast(data)[0] >> 6) == kRtpVersion; } -// Check the RTP payload type. If 63 < payload type < 96, it's RTCP. -// For additional details, see http://tools.ietf.org/html/rfc5761. -bool IsRtcp(const char* data, int len) { - if (len < 2) { - return false; - } - char pt = data[1] & 0x7F; - return (63 < pt) && (pt < 96); -} - bool IsValidRtpPayloadType(int payload_type) { return payload_type >= 0 && payload_type <= 127; } diff --git a/media/base/rtputils.h b/media/base/rtputils.h index 531a2cfeb1..0b7205cf8f 100644 --- a/media/base/rtputils.h +++ b/media/base/rtputils.h @@ -55,7 +55,6 @@ bool SetRtpHeader(void* data, size_t len, const RtpHeader& header); bool IsRtpPacket(const void* data, size_t len); -bool IsRtcp(const char* data, int len); // True if |payload type| is 0-127. bool IsValidRtpPayloadType(int payload_type); diff --git a/pc/BUILD.gn b/pc/BUILD.gn index 83bc88bd48..907d199168 100644 --- a/pc/BUILD.gn +++ b/pc/BUILD.gn @@ -30,6 +30,8 @@ rtc_static_library("rtc_pc_base") { defines = [] sources = [ "audiomonitor.h", + "bundlefilter.cc", + "bundlefilter.h", "channel.cc", "channel.h", "channelmanager.cc", @@ -78,13 +80,10 @@ rtc_static_library("rtc_pc_base") { "../api:optional", "../api:ortc_api", "../api:video_frame_api", - "../call:rtp_interfaces", - "../call:rtp_receiver", "../common_video:common_video", "../media:rtc_data", "../media:rtc_h264_profile_id", "../media:rtc_media_base", - "../modules/rtp_rtcp:rtp_rtcp_format", "../p2p:rtc_p2p", "../rtc_base:checks", "../rtc_base:rtc_base", @@ -275,6 +274,7 @@ if (rtc_include_tests) { testonly = true sources = [ + "bundlefilter_unittest.cc", "channel_unittest.cc", "channelmanager_unittest.cc", "currentspeakermonitor_unittest.cc", @@ -314,11 +314,9 @@ if (rtc_include_tests) { "../api:array_view", "../api:fakemetricsobserver", "../api:libjingle_peerconnection_api", - "../call:rtp_interfaces", "../logging:rtc_event_log_api", "../media:rtc_media_base", "../media:rtc_media_tests_utils", - "../modules/rtp_rtcp:rtp_rtcp_format", "../p2p:p2p_test_utils", "../p2p:rtc_p2p", "../rtc_base:checks", diff --git a/pc/bundlefilter.cc b/pc/bundlefilter.cc new file mode 100644 index 0000000000..7791da6274 --- /dev/null +++ b/pc/bundlefilter.cc @@ -0,0 +1,49 @@ +/* + * Copyright 2004 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "pc/bundlefilter.h" + +#include "media/base/rtputils.h" +#include "rtc_base/logging.h" + +namespace cricket { + +BundleFilter::BundleFilter() { +} + +BundleFilter::~BundleFilter() { +} + +bool BundleFilter::DemuxPacket(const uint8_t* data, size_t len) { + // For RTP packets, we check whether the payload type can be found. + if (!IsRtpPacket(data, len)) { + return false; + } + + int payload_type = 0; + if (!GetRtpPayloadType(data, len, &payload_type)) { + return false; + } + return FindPayloadType(payload_type); +} + +void BundleFilter::AddPayloadType(int payload_type) { + payload_types_.insert(payload_type); +} + +bool BundleFilter::FindPayloadType(int pl_type) const { + return payload_types_.find(pl_type) != payload_types_.end(); +} + +void BundleFilter::ClearAllPayloadTypes() { + payload_types_.clear(); +} + +} // namespace cricket diff --git a/pc/bundlefilter.h b/pc/bundlefilter.h new file mode 100644 index 0000000000..7decbba8a4 --- /dev/null +++ b/pc/bundlefilter.h @@ -0,0 +1,54 @@ +/* + * Copyright 2004 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef PC_BUNDLEFILTER_H_ +#define PC_BUNDLEFILTER_H_ + +#include + +#include +#include + +#include "media/base/streamparams.h" +#include "rtc_base/basictypes.h" + +namespace cricket { + +// In case of single RTP session and single transport channel, all session +// (or media) channels share a common transport channel. Hence they all get +// SignalReadPacket when packet received on transport channel. This requires +// cricket::BaseChannel to know all the valid sources, else media channel +// will decode invalid packets. +// +// This class determines whether a packet is destined for cricket::BaseChannel. +// This is only to be used for RTP packets as RTCP packets are not filtered. +// For RTP packets, this is decided based on the payload type. +class BundleFilter { + public: + BundleFilter(); + ~BundleFilter(); + + // Determines if a RTP packet belongs to valid cricket::BaseChannel. + bool DemuxPacket(const uint8_t* data, size_t len); + + // Adds the supported payload type. + void AddPayloadType(int payload_type); + + // Public for unittests. + bool FindPayloadType(int pl_type) const; + void ClearAllPayloadTypes(); + + private: + std::set payload_types_; +}; + +} // namespace cricket + +#endif // PC_BUNDLEFILTER_H_ diff --git a/pc/bundlefilter_unittest.cc b/pc/bundlefilter_unittest.cc new file mode 100644 index 0000000000..2b1af5c3a3 --- /dev/null +++ b/pc/bundlefilter_unittest.cc @@ -0,0 +1,72 @@ +/* + * Copyright 2004 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "pc/bundlefilter.h" +#include "rtc_base/gunit.h" + +using cricket::StreamParams; + +static const int kPayloadType1 = 0x11; +static const int kPayloadType2 = 0x22; +static const int kPayloadType3 = 0x33; + +// SSRC = 0x1111, Payload type = 0x11 +static const unsigned char kRtpPacketPt1Ssrc1[] = { + 0x80, kPayloadType1, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, + 0x11, +}; + +// SSRC = 0x2222, Payload type = 0x22 +static const unsigned char kRtpPacketPt2Ssrc2[] = { + 0x80, 0x80 + kPayloadType2, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x22, 0x22, +}; + +// SSRC = 0x2222, Payload type = 0x33 +static const unsigned char kRtpPacketPt3Ssrc2[] = { + 0x80, kPayloadType3, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x22, + 0x22, +}; + +// An SCTP packet. +static const unsigned char kSctpPacket[] = { + 0x00, 0x01, 0x00, 0x01, + 0xff, 0xff, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x04, + 0x00, 0x00, 0x00, 0x00, +}; + +TEST(BundleFilterTest, RtpPacketTest) { + cricket::BundleFilter bundle_filter; + bundle_filter.AddPayloadType(kPayloadType1); + EXPECT_TRUE(bundle_filter.DemuxPacket(kRtpPacketPt1Ssrc1, + sizeof(kRtpPacketPt1Ssrc1))); + bundle_filter.AddPayloadType(kPayloadType2); + EXPECT_TRUE(bundle_filter.DemuxPacket(kRtpPacketPt2Ssrc2, + sizeof(kRtpPacketPt2Ssrc2))); + + // Payload type 0x33 is not added. + EXPECT_FALSE(bundle_filter.DemuxPacket(kRtpPacketPt3Ssrc2, + sizeof(kRtpPacketPt3Ssrc2))); + // Size is too small. + EXPECT_FALSE(bundle_filter.DemuxPacket(kRtpPacketPt1Ssrc1, 11)); + + bundle_filter.ClearAllPayloadTypes(); + EXPECT_FALSE(bundle_filter.DemuxPacket(kRtpPacketPt1Ssrc1, + sizeof(kRtpPacketPt1Ssrc1))); + EXPECT_FALSE(bundle_filter.DemuxPacket(kRtpPacketPt2Ssrc2, + sizeof(kRtpPacketPt2Ssrc2))); +} + +TEST(BundleFilterTest, InvalidRtpPacket) { + cricket::BundleFilter bundle_filter; + EXPECT_FALSE(bundle_filter.DemuxPacket(kSctpPacket, sizeof(kSctpPacket))); +} diff --git a/pc/channel.cc b/pc/channel.cc index e08e2b502d..358cb77e55 100644 --- a/pc/channel.cc +++ b/pc/channel.cc @@ -29,7 +29,6 @@ // Adding 'nogncheck' to disable the gn include headers check to support modular // WebRTC build targets. #include "media/engine/webrtcvoiceengine.h" // nogncheck -#include "modules/rtp_rtcp/source/rtp_packet_received.h" #include "p2p/base/packettransportinternal.h" #include "pc/channelmanager.h" #include "pc/rtpmediautils.h" @@ -111,7 +110,6 @@ BaseChannel::BaseChannel(rtc::Thread* worker_thread, srtp_required_(srtp_required), media_channel_(std::move(media_channel)) { RTC_DCHECK_RUN_ON(worker_thread_); - demuxer_criteria_.mid = content_name; rtp_transport_ = unencrypted_rtp_transport_.get(); ConnectToRtpTransport(); RTC_LOG(LS_INFO) << "Created channel for " << content_name; @@ -133,12 +131,13 @@ BaseChannel::~BaseChannel() { void BaseChannel::ConnectToRtpTransport() { RTC_DCHECK(rtp_transport_); - bool success = RegisterRtpDemuxerSink(); - RTC_DCHECK(success); rtp_transport_->SignalReadyToSend.connect( this, &BaseChannel::OnTransportReadyToSend); - rtp_transport_->SignalRtcpPacketReceived.connect( - this, &BaseChannel::OnRtcpPacketReceived); + // TODO(zstein): RtpTransport::SignalPacketReceived will probably be replaced + // with a callback interface later so that the demuxer can select which + // channel to signal. + rtp_transport_->SignalPacketReceived.connect(this, + &BaseChannel::OnPacketReceived); rtp_transport_->SignalNetworkRouteChanged.connect( this, &BaseChannel::OnNetworkRouteChanged); rtp_transport_->SignalWritableState.connect(this, @@ -155,9 +154,8 @@ void BaseChannel::ConnectToRtpTransport() { void BaseChannel::DisconnectFromRtpTransport() { RTC_DCHECK(rtp_transport_); - rtp_transport_->UnregisterRtpDemuxerSink(this); rtp_transport_->SignalReadyToSend.disconnect(this); - rtp_transport_->SignalRtcpPacketReceived.disconnect(this); + rtp_transport_->SignalPacketReceived.disconnect(this); rtp_transport_->SignalNetworkRouteChanged.disconnect(this); rtp_transport_->SignalWritableState.disconnect(this); rtp_transport_->SignalSentPacket.disconnect(this); @@ -205,26 +203,17 @@ void BaseChannel::Deinit() { // functions, so need to stop this process in Deinit that is called in // derived classes destructor. network_thread_->Invoke(RTC_FROM_HERE, [&] { - if (rtp_transport_) { - FlushRtcpMessages_n(); - if (dtls_srtp_transport_) { - dtls_srtp_transport_->SetDtlsTransports(nullptr, nullptr); - } else { - rtp_transport_->SetRtpPacketTransport(nullptr); - rtp_transport_->SetRtcpPacketTransport(nullptr); - } - DisconnectFromRtpTransport(); - } + FlushRtcpMessages_n(); + if (dtls_srtp_transport_) { + dtls_srtp_transport_->SetDtlsTransports(nullptr, nullptr); + } else { + rtp_transport_->SetRtpPacketTransport(nullptr); + rtp_transport_->SetRtcpPacketTransport(nullptr); + } // Clear pending read packets/messages. network_thread_->Clear(&invoker_); network_thread_->Clear(this); - // Because RTP level transports are accessed from the |network_thread_|, - // it's safer to release them from the |network_thread_| as well. - unencrypted_rtp_transport_.reset(); - sdes_transport_.reset(); - dtls_srtp_transport_.reset(); - rtp_transport_ = nullptr; }); } @@ -235,10 +224,8 @@ void BaseChannel::SetRtpTransport(webrtc::RtpTransportInternal* rtp_transport) { return; }); } + RTC_DCHECK(rtp_transport); - if (rtp_transport == rtp_transport_) { - return; - } if (rtp_transport_) { DisconnectFromRtpTransport(); @@ -594,37 +581,12 @@ bool BaseChannel::SendPacket(bool rtcp, : rtp_transport_->SendRtpPacket(packet, options, PF_SRTP_BYPASS); } -void BaseChannel::OnRtpPacket(const webrtc::RtpPacketReceived& parsed_packet) { - // Reconstruct the PacketTime from the |parsed_packet|. - // RtpPacketReceived.arrival_time_ms = (PacketTime + 500) / 1000; - // Note: The |not_before| field is always 0 here. This field is not currently - // used, so it should be fine. - int64_t timestamp = -1; - if (parsed_packet.arrival_time_ms() > 0) { - timestamp = parsed_packet.arrival_time_ms() * 1000; - } - rtc::PacketTime packet_time(timestamp, /*not_before=*/0); - OnPacketReceived(/*rtcp=*/false, parsed_packet.Buffer(), packet_time); -} - -void BaseChannel::UpdateRtpHeaderExtensionMap( - const RtpHeaderExtensions& header_extensions) { - RTC_DCHECK(rtp_transport_); - rtp_transport_->UpdateRtpHeaderExtensionMap(header_extensions); -} - -bool BaseChannel::RegisterRtpDemuxerSink() { - RTC_DCHECK(rtp_transport_); - return rtp_transport_->RegisterRtpDemuxerSink(demuxer_criteria_, this); -} - -void BaseChannel::OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, - const rtc::PacketTime& packet_time) { - OnPacketReceived(/*rtcp=*/true, *packet, packet_time); +bool BaseChannel::HandlesPayloadType(int packet_type) const { + return rtp_transport_->HandlesPayloadType(packet_type); } void BaseChannel::OnPacketReceived(bool rtcp, - const rtc::CopyOnWriteBuffer& packet, + rtc::CopyOnWriteBuffer* packet, const rtc::PacketTime& packet_time) { if (!has_received_packet_ && !rtcp) { has_received_packet_ = true; @@ -651,7 +613,7 @@ void BaseChannel::OnPacketReceived(bool rtcp, invoker_.AsyncInvoke( RTC_FROM_HERE, worker_thread_, - Bind(&BaseChannel::ProcessPacket, this, rtcp, packet, packet_time)); + Bind(&BaseChannel::ProcessPacket, this, rtcp, *packet, packet_time)); } void BaseChannel::ProcessPacket(bool rtcp, @@ -793,17 +755,14 @@ void BaseChannel::EnableSdes_n() { // DtlsSrtpTransport and SrtpTransport shouldn't be enabled at the same // time. RTC_DCHECK(!dtls_srtp_transport_); - - sdes_transport_ = rtc::MakeUnique(rtcp_mux_required_); + RTC_DCHECK(unencrypted_rtp_transport_); + sdes_transport_ = rtc::MakeUnique( + std::move(unencrypted_rtp_transport_)); #if defined(ENABLE_EXTERNAL_AUTH) sdes_transport_->EnableExternalAuth(); #endif - sdes_transport_->SetRtpPacketTransport( - rtp_transport_->rtp_packet_transport()); - sdes_transport_->SetRtcpPacketTransport( - rtp_transport_->rtcp_packet_transport()); SetRtpTransport(sdes_transport_.get()); - RTC_LOG(LS_INFO) << "SrtpTransport is created for SDES."; + RTC_LOG(LS_INFO) << "Wrapping RtpTransport in SrtpTransport."; } void BaseChannel::EnableDtlsSrtp_n() { @@ -813,12 +772,15 @@ void BaseChannel::EnableDtlsSrtp_n() { // DtlsSrtpTransport and SrtpTransport shouldn't be enabled at the same // time. RTC_DCHECK(!sdes_transport_); + RTC_DCHECK(unencrypted_rtp_transport_); - dtls_srtp_transport_ = - rtc::MakeUnique(rtcp_mux_required_); + auto srtp_transport = rtc::MakeUnique( + std::move(unencrypted_rtp_transport_)); #if defined(ENABLE_EXTERNAL_AUTH) - dtls_srtp_transport_->EnableExternalAuth(); + srtp_transport->EnableExternalAuth(); #endif + dtls_srtp_transport_ = + rtc::MakeUnique(std::move(srtp_transport)); SetRtpTransport(dtls_srtp_transport_.get()); if (cached_send_extension_ids_) { @@ -834,7 +796,8 @@ void BaseChannel::EnableDtlsSrtp_n() { RTC_DCHECK(rtp_dtls_transport_); dtls_srtp_transport_->SetDtlsTransports(rtp_dtls_transport_, rtcp_dtls_transport_); - RTC_LOG(LS_INFO) << "DtlsSrtpTransport is created for DTLS-SRTP."; + + RTC_LOG(LS_INFO) << "Wrapping SrtpTransport in DtlsSrtpTransport."; } bool BaseChannel::SetSrtp_n(const std::vector& cryptos, @@ -1114,7 +1077,7 @@ void BaseChannel::OnMessage(rtc::Message *pmsg) { } void BaseChannel::AddHandledPayloadType(int payload_type) { - demuxer_criteria_.payload_types.insert(static_cast(payload_type)); + rtp_transport_->AddHandledPayloadType(payload_type); } void BaseChannel::FlushRtcpMessages_n() { @@ -1244,7 +1207,6 @@ bool VoiceChannel::SetLocalContent_w(const MediaContentDescription* content, RtpHeaderExtensions rtp_header_extensions = GetFilteredRtpHeaderExtensions(audio->rtp_header_extensions()); - UpdateRtpHeaderExtensionMap(rtp_header_extensions); if (!SetRtpTransportParameters(content, type, CS_LOCAL, rtp_header_extensions, error_desc)) { @@ -1258,16 +1220,9 @@ bool VoiceChannel::SetLocalContent_w(const MediaContentDescription* content, error_desc); return false; } - for (const AudioCodec& codec : audio->codecs()) { AddHandledPayloadType(codec.id); } - // Need to re-register the sink to update the handled payload. - if (!RegisterRtpDemuxerSink()) { - RTC_LOG(LS_ERROR) << "Failed to set up audio demuxing."; - return false; - } - last_recv_params_ = recv_params; // TODO(pthatcher): Move local streams into AudioSendParameters, and @@ -1302,7 +1257,6 @@ bool VoiceChannel::SetRemoteContent_w(const MediaContentDescription* content, RtpHeaderExtensions rtp_header_extensions = GetFilteredRtpHeaderExtensions(audio->rtp_header_extensions()); - UpdateRtpHeaderExtensionMap(rtp_header_extensions); if (!SetRtpTransportParameters(content, type, CS_REMOTE, rtp_header_extensions, error_desc)) { return false; @@ -1358,6 +1312,7 @@ VideoChannel::~VideoChannel() { TRACE_EVENT0("webrtc", "VideoChannel::~VideoChannel"); // this can't be done in the base class, since it calls a virtual DisableMedia_w(); + Deinit(); } @@ -1395,7 +1350,6 @@ bool VideoChannel::SetLocalContent_w(const MediaContentDescription* content, RtpHeaderExtensions rtp_header_extensions = GetFilteredRtpHeaderExtensions(video->rtp_header_extensions()); - UpdateRtpHeaderExtensionMap(rtp_header_extensions); if (!SetRtpTransportParameters(content, type, CS_LOCAL, rtp_header_extensions, error_desc)) { @@ -1409,16 +1363,9 @@ bool VideoChannel::SetLocalContent_w(const MediaContentDescription* content, error_desc); return false; } - for (const VideoCodec& codec : video->codecs()) { AddHandledPayloadType(codec.id); } - // Need to re-register the sink to update the handled payload. - if (!RegisterRtpDemuxerSink()) { - RTC_LOG(LS_ERROR) << "Failed to set up video demuxing."; - return false; - } - last_recv_params_ = recv_params; // TODO(pthatcher): Move local streams into VideoSendParameters, and @@ -1453,7 +1400,6 @@ bool VideoChannel::SetRemoteContent_w(const MediaContentDescription* content, RtpHeaderExtensions rtp_header_extensions = GetFilteredRtpHeaderExtensions(video->rtp_header_extensions()); - UpdateRtpHeaderExtensionMap(rtp_header_extensions); if (!SetRtpTransportParameters(content, type, CS_REMOTE, rtp_header_extensions, error_desc)) { return false; @@ -1513,6 +1459,7 @@ RtpDataChannel::~RtpDataChannel() { TRACE_EVENT0("webrtc", "RtpDataChannel::~RtpDataChannel"); // this can't be done in the base class, since it calls a virtual DisableMedia_w(); + Deinit(); } @@ -1581,7 +1528,6 @@ bool RtpDataChannel::SetLocalContent_w(const MediaContentDescription* content, RtpHeaderExtensions rtp_header_extensions = GetFilteredRtpHeaderExtensions(data->rtp_header_extensions()); - UpdateRtpHeaderExtensionMap(rtp_header_extensions); if (!SetRtpTransportParameters(content, type, CS_LOCAL, rtp_header_extensions, error_desc)) { @@ -1595,16 +1541,9 @@ bool RtpDataChannel::SetLocalContent_w(const MediaContentDescription* content, error_desc); return false; } - for (const DataCodec& codec : data->codecs()) { AddHandledPayloadType(codec.id); } - // Need to re-register the sink to update the handled payload. - if (!RegisterRtpDemuxerSink()) { - RTC_LOG(LS_ERROR) << "Failed to set up data demuxing."; - return false; - } - last_recv_params_ = recv_params; // TODO(pthatcher): Move local streams into DataSendParameters, and @@ -1648,7 +1587,6 @@ bool RtpDataChannel::SetRemoteContent_w(const MediaContentDescription* content, RtpHeaderExtensions rtp_header_extensions = GetFilteredRtpHeaderExtensions(data->rtp_header_extensions()); - UpdateRtpHeaderExtensionMap(rtp_header_extensions); RTC_LOG(LS_INFO) << "Setting remote data description"; if (!SetRtpTransportParameters(content, type, CS_REMOTE, rtp_header_extensions, error_desc)) { diff --git a/pc/channel.h b/pc/channel.h index 52474c3a14..6a8367a754 100644 --- a/pc/channel.h +++ b/pc/channel.h @@ -23,7 +23,6 @@ #include "api/rtpreceiverinterface.h" #include "api/videosinkinterface.h" #include "api/videosourceinterface.h" -#include "call/rtp_packet_sink_interface.h" #include "media/base/mediachannel.h" #include "media/base/mediaengine.h" #include "media/base/streamparams.h" @@ -70,10 +69,9 @@ class MediaContentDescription; // vtable, and the media channel's thread using BaseChannel as the // NetworkInterface. -class BaseChannel : public rtc::MessageHandler, - public sigslot::has_slots<>, - public MediaChannel::NetworkInterface, - public webrtc::RtpPacketSinkInterface { +class BaseChannel + : public rtc::MessageHandler, public sigslot::has_slots<>, + public MediaChannel::NetworkInterface { public: // If |srtp_required| is true, the channel will not send or receive any // RTP/RTCP packets without using SRTP (either using SDES or DTLS-SRTP). @@ -195,8 +193,10 @@ class BaseChannel : public rtc::MessageHandler, virtual cricket::MediaType media_type() = 0; - // RtpPacketSinkInterface overrides. - void OnRtpPacket(const webrtc::RtpPacketReceived& packet) override; + // Public for testing. + // TODO(zstein): Remove this once channels register themselves with + // an RtpTransport in a more explicit way. + bool HandlesPayloadType(int payload_type) const; // Used by the RTCStatsCollector tests to set the transport name without // creating RtpTransports. @@ -264,10 +264,12 @@ class BaseChannel : public rtc::MessageHandler, rtc::CopyOnWriteBuffer* packet, const rtc::PacketOptions& options); - void OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, - const rtc::PacketTime& packet_time); + bool WantsPacket(bool rtcp, const rtc::CopyOnWriteBuffer* packet); + void HandlePacket(bool rtcp, rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& packet_time); + // TODO(zstein): packet can be const once the RtpTransport handles protection. void OnPacketReceived(bool rtcp, - const rtc::CopyOnWriteBuffer& packet, + rtc::CopyOnWriteBuffer* packet, const rtc::PacketTime& packet_time); void ProcessPacket(bool rtcp, const rtc::CopyOnWriteBuffer& packet, @@ -358,11 +360,6 @@ class BaseChannel : public rtc::MessageHandler, void AddHandledPayloadType(int payload_type); - void UpdateRtpHeaderExtensionMap( - const RtpHeaderExtensions& header_extensions); - - bool RegisterRtpDemuxerSink(); - private: void ConnectToRtpTransport(); void DisconnectFromRtpTransport(); @@ -442,9 +439,6 @@ class BaseChannel : public rtc::MessageHandler, // The cached encrypted header extension IDs. rtc::Optional> cached_send_extension_ids_; rtc::Optional> cached_recv_extension_ids_; - - bool encryption_disabled_ = false; - webrtc::RtpDemuxerCriteria demuxer_criteria_; }; // VoiceChannel is a specialization that adds support for early media, DTMF, diff --git a/pc/channel_unittest.cc b/pc/channel_unittest.cc index 753d6cd618..7ee35013ac 100644 --- a/pc/channel_unittest.cc +++ b/pc/channel_unittest.cc @@ -1609,6 +1609,10 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { EXPECT_TRUE(SendAccept()); EXPECT_EQ(rtcp_mux, !channel1_->NeedsRtcpTransport()); EXPECT_EQ(rtcp_mux, !channel2_->NeedsRtcpTransport()); + EXPECT_TRUE(channel1_->HandlesPayloadType(pl_type1)); + EXPECT_TRUE(channel2_->HandlesPayloadType(pl_type1)); + EXPECT_FALSE(channel1_->HandlesPayloadType(pl_type2)); + EXPECT_FALSE(channel2_->HandlesPayloadType(pl_type2)); // Both channels can receive pl_type1 only. SendCustomRtp1(kSsrc1, ++sequence_number1_1, pl_type1); @@ -1619,15 +1623,13 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { EXPECT_TRUE(CheckNoRtp1()); EXPECT_TRUE(CheckNoRtp2()); - EXPECT_TRUE(SendInitiate()); - EXPECT_TRUE(SendAccept()); + // RTCP test SendCustomRtp1(kSsrc1, ++sequence_number1_1, pl_type2); SendCustomRtp2(kSsrc2, ++sequence_number2_2, pl_type2); WaitForThreads(); EXPECT_FALSE(CheckCustomRtp2(kSsrc1, sequence_number1_1, pl_type2)); EXPECT_FALSE(CheckCustomRtp1(kSsrc2, sequence_number2_2, pl_type2)); - // RTCP test SendCustomRtcp1(kSsrc1); SendCustomRtcp2(kSsrc2); WaitForThreads(); diff --git a/pc/channelmanager.cc b/pc/channelmanager.cc index 2ae35f88eb..ead1da9cdd 100644 --- a/pc/channelmanager.cc +++ b/pc/channelmanager.cc @@ -507,6 +507,7 @@ void ChannelManager::DestroyRtpDataChannel(RtpDataChannel* data_channel) { if (it == data_channels_.end()) { return; } + data_channels_.erase(it); } diff --git a/pc/channelmanager_unittest.cc b/pc/channelmanager_unittest.cc index bf799f6d08..d318ac5b20 100644 --- a/pc/channelmanager_unittest.cc +++ b/pc/channelmanager_unittest.cc @@ -195,14 +195,11 @@ class ChannelManagerTestWithRtpTransport RTPTransportType type = GetParam(); switch (type) { case RTPTransportType::kRtp: - return rtc::MakeUnique( - /*rtcp_mux_required=*/true); + return CreatePlainRtpTransport(); case RTPTransportType::kSrtp: - return rtc::MakeUnique( - /*rtcp_mux_required=*/true); + return CreateSrtpTransport(); case RTPTransportType::kDtlsSrtp: - return rtc::MakeUnique( - /*rtcp_mux_required=*/true); + return CreateDtlsSrtpTransport(); } return nullptr; } @@ -227,6 +224,29 @@ class ChannelManagerTestWithRtpTransport cm_->DestroyRtpDataChannel(rtp_data_channel); cm_->Terminate(); } + + private: + std::unique_ptr CreatePlainRtpTransport() { + return rtc::MakeUnique(/*rtcp_mux_required=*/true); + } + + std::unique_ptr CreateSrtpTransport() { + auto rtp_transport = + rtc::MakeUnique(/*rtcp_mux_required=*/true); + auto srtp_transport = + rtc::MakeUnique(std::move(rtp_transport)); + return srtp_transport; + } + + std::unique_ptr CreateDtlsSrtpTransport() { + auto rtp_transport = + rtc::MakeUnique(/*rtcp_mux_required=*/true); + auto srtp_transport = + rtc::MakeUnique(std::move(rtp_transport)); + auto dtls_srtp_transport_ = + rtc::MakeUnique(std::move(srtp_transport)); + return dtls_srtp_transport_; + } }; TEST_P(ChannelManagerTestWithRtpTransport, CreateDestroyChannels) { diff --git a/pc/dtlssrtptransport.cc b/pc/dtlssrtptransport.cc index b85930c56d..0b98a96293 100644 --- a/pc/dtlssrtptransport.cc +++ b/pc/dtlssrtptransport.cc @@ -24,8 +24,22 @@ static const char kDtlsSrtpExporterLabel[] = "EXTRACTOR-dtls_srtp"; namespace webrtc { -DtlsSrtpTransport::DtlsSrtpTransport(bool rtcp_mux_enabled) - : SrtpTransport(rtcp_mux_enabled) {} +DtlsSrtpTransport::DtlsSrtpTransport( + std::unique_ptr srtp_transport) + : RtpTransportInternalAdapter(srtp_transport.get()) { + srtp_transport_ = std::move(srtp_transport); + RTC_DCHECK(srtp_transport_); + srtp_transport_->SignalPacketReceived.connect( + this, &DtlsSrtpTransport::OnPacketReceived); + srtp_transport_->SignalReadyToSend.connect(this, + &DtlsSrtpTransport::OnReadyToSend); + srtp_transport_->SignalNetworkRouteChanged.connect( + this, &DtlsSrtpTransport::OnNetworkRouteChanged); + srtp_transport_->SignalWritableState.connect( + this, &DtlsSrtpTransport::OnWritableState); + srtp_transport_->SignalSentPacket.connect(this, + &DtlsSrtpTransport::OnSentPacket); +} void DtlsSrtpTransport::SetDtlsTransports( cricket::DtlsTransportInternal* rtp_dtls_transport, @@ -40,7 +54,7 @@ void DtlsSrtpTransport::SetDtlsTransports( // DtlsTransport changes and wait until the DTLS handshake is complete to set // the newly negotiated parameters. if (IsActive()) { - ResetParams(); + srtp_transport_->ResetParams(); } const std::string transport_name = @@ -66,7 +80,7 @@ void DtlsSrtpTransport::SetDtlsTransports( } void DtlsSrtpTransport::SetRtcpMuxEnabled(bool enable) { - SrtpTransport::SetRtcpMuxEnabled(enable); + srtp_transport_->SetRtcpMuxEnabled(enable); if (enable) { UpdateWritableStateAndMaybeSetupDtlsSrtp(); } @@ -114,9 +128,10 @@ bool DtlsSrtpTransport::IsDtlsConnected() { } bool DtlsSrtpTransport::IsDtlsWritable() { + auto rtp_packet_transport = srtp_transport_->rtp_packet_transport(); auto rtcp_packet_transport = - rtcp_mux_enabled() ? nullptr : rtcp_dtls_transport_; - return rtp_dtls_transport_ && rtp_dtls_transport_->writable() && + rtcp_mux_enabled() ? nullptr : srtp_transport_->rtcp_packet_transport(); + return rtp_packet_transport && rtp_packet_transport->writable() && (!rtcp_packet_transport || rtcp_packet_transport->writable()); } @@ -155,10 +170,11 @@ void DtlsSrtpTransport::SetupRtpDtlsSrtp() { if (!ExtractParams(rtp_dtls_transport_, &selected_crypto_suite, &send_key, &recv_key) || - !SetRtpParams(selected_crypto_suite, &send_key[0], - static_cast(send_key.size()), send_extension_ids, - selected_crypto_suite, &recv_key[0], - static_cast(recv_key.size()), recv_extension_ids)) { + !srtp_transport_->SetRtpParams( + selected_crypto_suite, &send_key[0], + static_cast(send_key.size()), send_extension_ids, + selected_crypto_suite, &recv_key[0], + static_cast(recv_key.size()), recv_extension_ids)) { SignalDtlsSrtpSetupFailure(this, /*rtcp=*/false); RTC_LOG(LS_WARNING) << "DTLS-SRTP key installation for RTP failed"; } @@ -186,11 +202,11 @@ void DtlsSrtpTransport::SetupRtcpDtlsSrtp() { rtc::ZeroOnFreeBuffer rtcp_recv_key; if (!ExtractParams(rtcp_dtls_transport_, &selected_crypto_suite, &rtcp_send_key, &rtcp_recv_key) || - !SetRtcpParams(selected_crypto_suite, &rtcp_send_key[0], - static_cast(rtcp_send_key.size()), send_extension_ids, - selected_crypto_suite, &rtcp_recv_key[0], - static_cast(rtcp_recv_key.size()), - recv_extension_ids)) { + !srtp_transport_->SetRtcpParams( + selected_crypto_suite, &rtcp_send_key[0], + static_cast(rtcp_send_key.size()), send_extension_ids, + selected_crypto_suite, &rtcp_recv_key[0], + static_cast(rtcp_recv_key.size()), recv_extension_ids)) { SignalDtlsSrtpSetupFailure(this, /*rtcp=*/true); RTC_LOG(LS_WARNING) << "DTLS-SRTP key installation for RTCP failed"; } @@ -313,20 +329,37 @@ void DtlsSrtpTransport::OnDtlsState(cricket::DtlsTransportInternal* transport, transport == rtcp_dtls_transport_); if (state != cricket::DTLS_TRANSPORT_CONNECTED) { - ResetParams(); + srtp_transport_->ResetParams(); return; } MaybeSetupDtlsSrtp(); } -void DtlsSrtpTransport::OnWritableState( - rtc::PacketTransportInternal* packet_transport) { - bool writable = IsTransportWritable(); +void DtlsSrtpTransport::OnWritableState(bool writable) { SetWritable(writable); if (writable) { MaybeSetupDtlsSrtp(); } } +void DtlsSrtpTransport::OnSentPacket(const rtc::SentPacket& sent_packet) { + SignalSentPacket(sent_packet); +} + +void DtlsSrtpTransport::OnPacketReceived(bool rtcp, + rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& packet_time) { + SignalPacketReceived(rtcp, packet, packet_time); +} + +void DtlsSrtpTransport::OnReadyToSend(bool ready) { + SignalReadyToSend(ready); +} + +void DtlsSrtpTransport::OnNetworkRouteChanged( + rtc::Optional network_route) { + SignalNetworkRouteChanged(network_route); +} + } // namespace webrtc diff --git a/pc/dtlssrtptransport.h b/pc/dtlssrtptransport.h index fdd54b2615..02002b052a 100644 --- a/pc/dtlssrtptransport.h +++ b/pc/dtlssrtptransport.h @@ -16,17 +16,20 @@ #include #include "p2p/base/dtlstransportinternal.h" +#include "pc/rtptransportinternaladapter.h" #include "pc/srtptransport.h" #include "rtc_base/buffer.h" namespace webrtc { -// The subclass of SrtpTransport is used for DTLS-SRTP. When the DTLS handshake -// is finished, it extracts the keying materials from DtlsTransport and -// configures the SrtpSessions in the base class. -class DtlsSrtpTransport : public SrtpTransport { +// This class is intended to be used as an RtpTransport and it wraps both an +// SrtpTransport and DtlsTransports(RTP/RTCP). When the DTLS handshake is +// finished, it extracts the keying materials from DtlsTransport and sets them +// to SrtpTransport. +class DtlsSrtpTransport : public RtpTransportInternalAdapter { public: - explicit DtlsSrtpTransport(bool rtcp_mux_enabled); + explicit DtlsSrtpTransport( + std::unique_ptr srtp_transport); // Set P2P layer RTP/RTCP DtlsTransports. When using RTCP-muxing, // |rtcp_dtls_transport| is null. @@ -42,6 +45,15 @@ class DtlsSrtpTransport : public SrtpTransport { void UpdateRecvEncryptedHeaderExtensionIds( const std::vector& recv_extension_ids); + bool IsActive() { return srtp_transport_->IsActive(); } + + // Cache RTP Absoulute SendTime extension header ID. This is only used when + // external authentication is enabled. + void CacheRtpAbsSendTimeHeaderExtension(int rtp_abs_sendtime_extn_id) { + srtp_transport_->CacheRtpAbsSendTimeHeaderExtension( + rtp_abs_sendtime_extn_id); + } + // TODO(zhihuang): Remove this when we remove RtpTransportAdapter. RtpTransportAdapter* GetInternal() override { return nullptr; } @@ -71,11 +83,16 @@ class DtlsSrtpTransport : public SrtpTransport { void OnDtlsState(cricket::DtlsTransportInternal* dtls_transport, cricket::DtlsTransportState state); - - // Override the RtpTransport::OnWritableState. - void OnWritableState(rtc::PacketTransportInternal* packet_transport) override; + void OnWritableState(bool writable); + void OnSentPacket(const rtc::SentPacket& sent_packet); + void OnPacketReceived(bool rtcp, + rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& packet_time); + void OnReadyToSend(bool ready); + void OnNetworkRouteChanged(rtc::Optional network_route); bool writable_ = false; + std::unique_ptr srtp_transport_; // Owned by the TransportController. cricket::DtlsTransportInternal* rtp_dtls_transport_ = nullptr; cricket::DtlsTransportInternal* rtcp_dtls_transport_ = nullptr; diff --git a/pc/dtlssrtptransport_unittest.cc b/pc/dtlssrtptransport_unittest.cc index eb37b701db..08a8151ee7 100644 --- a/pc/dtlssrtptransport_unittest.cc +++ b/pc/dtlssrtptransport_unittest.cc @@ -33,26 +33,50 @@ using webrtc::RtpTransport; const int kRtpAuthTagLen = 10; +class TransportObserver : public sigslot::has_slots<> { + public: + void OnPacketReceived(bool rtcp, + rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& packet_time) { + rtcp ? last_recv_rtcp_packet_ = *packet : last_recv_rtp_packet_ = *packet; + } + + void OnReadyToSend(bool ready) { ready_to_send_ = ready; } + + rtc::CopyOnWriteBuffer last_recv_rtp_packet() { + return last_recv_rtp_packet_; + } + + rtc::CopyOnWriteBuffer last_recv_rtcp_packet() { + return last_recv_rtcp_packet_; + } + + bool ready_to_send() { return ready_to_send_; } + + private: + rtc::CopyOnWriteBuffer last_recv_rtp_packet_; + rtc::CopyOnWriteBuffer last_recv_rtcp_packet_; + bool ready_to_send_ = false; +}; + class DtlsSrtpTransportTest : public testing::Test, public sigslot::has_slots<> { protected: DtlsSrtpTransportTest() {} - ~DtlsSrtpTransportTest() { - if (dtls_srtp_transport1_) { - dtls_srtp_transport1_->UnregisterRtpDemuxerSink(&transport_observer1_); - } - if (dtls_srtp_transport2_) { - dtls_srtp_transport2_->UnregisterRtpDemuxerSink(&transport_observer2_); - } - } - std::unique_ptr MakeDtlsSrtpTransport( FakeDtlsTransport* rtp_dtls, FakeDtlsTransport* rtcp_dtls, bool rtcp_mux_enabled) { + auto rtp_transport = rtc::MakeUnique(rtcp_mux_enabled); + + rtp_transport->AddHandledPayloadType(0x00); + rtp_transport->AddHandledPayloadType(0xc9); + + auto srtp_transport = + rtc::MakeUnique(std::move(rtp_transport)); auto dtls_srtp_transport = - rtc::MakeUnique(rtcp_mux_enabled); + rtc::MakeUnique(std::move(srtp_transport)); dtls_srtp_transport->SetDtlsTransports(rtp_dtls, rtcp_dtls); @@ -69,24 +93,15 @@ class DtlsSrtpTransportTest : public testing::Test, dtls_srtp_transport2_ = MakeDtlsSrtpTransport(rtp_dtls2, rtcp_dtls2, rtcp_mux_enabled); - dtls_srtp_transport1_->SignalRtcpPacketReceived.connect( - &transport_observer1_, - &webrtc::TransportObserver::OnRtcpPacketReceived); + dtls_srtp_transport1_->SignalPacketReceived.connect( + &transport_observer1_, &TransportObserver::OnPacketReceived); dtls_srtp_transport1_->SignalReadyToSend.connect( - &transport_observer1_, &webrtc::TransportObserver::OnReadyToSend); + &transport_observer1_, &TransportObserver::OnReadyToSend); - dtls_srtp_transport2_->SignalRtcpPacketReceived.connect( - &transport_observer2_, - &webrtc::TransportObserver::OnRtcpPacketReceived); + dtls_srtp_transport2_->SignalPacketReceived.connect( + &transport_observer2_, &TransportObserver::OnPacketReceived); dtls_srtp_transport2_->SignalReadyToSend.connect( - &transport_observer2_, &webrtc::TransportObserver::OnReadyToSend); - webrtc::RtpDemuxerCriteria demuxer_criteria; - // 0x00 is the payload type used in kPcmuFrame. - demuxer_criteria.payload_types = {0x00}; - dtls_srtp_transport1_->RegisterRtpDemuxerSink(demuxer_criteria, - &transport_observer1_); - dtls_srtp_transport2_->RegisterRtpDemuxerSink(demuxer_criteria, - &transport_observer2_); + &transport_observer2_, &TransportObserver::OnReadyToSend); } void CompleteDtlsHandshake(FakeDtlsTransport* fake_dtls1, @@ -236,8 +251,8 @@ class DtlsSrtpTransportTest : public testing::Test, std::unique_ptr dtls_srtp_transport1_; std::unique_ptr dtls_srtp_transport2_; - webrtc::TransportObserver transport_observer1_; - webrtc::TransportObserver transport_observer2_; + TransportObserver transport_observer1_; + TransportObserver transport_observer2_; int sequence_number_ = 0; }; diff --git a/pc/jseptransport2_unittest.cc b/pc/jseptransport2_unittest.cc index e578e6b5ec..fc098ae9b6 100644 --- a/pc/jseptransport2_unittest.cc +++ b/pc/jseptransport2_unittest.cc @@ -43,8 +43,9 @@ class JsepTransport2Test : public testing::Test, public sigslot::has_slots<> { const std::string& transport_name, rtc::PacketTransportInternal* rtp_packet_transport, rtc::PacketTransportInternal* rtcp_packet_transport) { - auto srtp_transport = rtc::MakeUnique( - rtcp_packet_transport == nullptr); + bool rtcp_mux_enabled = (rtcp_packet_transport == nullptr); + auto srtp_transport = + rtc::MakeUnique(rtcp_mux_enabled); srtp_transport->SetRtpPacketTransport(rtp_packet_transport); if (rtcp_packet_transport) { @@ -57,8 +58,11 @@ class JsepTransport2Test : public testing::Test, public sigslot::has_slots<> { const std::string& transport_name, cricket::DtlsTransportInternal* rtp_dtls_transport, cricket::DtlsTransportInternal* rtcp_dtls_transport) { - auto dtls_srtp_transport = rtc::MakeUnique( - rtcp_dtls_transport == nullptr); + bool rtcp_mux_enabled = (rtcp_dtls_transport == nullptr); + auto srtp_transport = + rtc::MakeUnique(rtcp_mux_enabled); + auto dtls_srtp_transport = + rtc::MakeUnique(std::move(srtp_transport)); dtls_srtp_transport->SetDtlsTransports(rtp_dtls_transport, rtcp_dtls_transport); diff --git a/pc/jseptransportcontroller.cc b/pc/jseptransportcontroller.cc index 10725b5127..5235791792 100644 --- a/pc/jseptransportcontroller.cc +++ b/pc/jseptransportcontroller.cc @@ -457,13 +457,16 @@ JsepTransportController::CreateDtlsSrtpTransport( cricket::DtlsTransportInternal* rtp_dtls_transport, cricket::DtlsTransportInternal* rtcp_dtls_transport) { RTC_DCHECK(network_thread_->IsCurrent()); - - auto dtls_srtp_transport = rtc::MakeUnique( - rtcp_dtls_transport == nullptr); + bool rtcp_mux_enabled = rtcp_dtls_transport == nullptr; + auto srtp_transport = + rtc::MakeUnique(rtcp_mux_enabled); if (config_.enable_external_auth) { - dtls_srtp_transport->EnableExternalAuth(); + srtp_transport->EnableExternalAuth(); } + auto dtls_srtp_transport = + rtc::MakeUnique(std::move(srtp_transport)); + dtls_srtp_transport->SetDtlsTransports(rtp_dtls_transport, rtcp_dtls_transport); return dtls_srtp_transport; diff --git a/pc/peerconnection.cc b/pc/peerconnection.cc index 538f5c1f78..eacfb8329b 100644 --- a/pc/peerconnection.cc +++ b/pc/peerconnection.cc @@ -6195,9 +6195,12 @@ void PeerConnection::DestroyDataChannel() { void PeerConnection::DestroyBaseChannel(cricket::BaseChannel* channel) { RTC_DCHECK(channel); + RTC_DCHECK(channel->rtp_dtls_transport()); + // Need to cache these before destroying the base channel so that we do not // access uninitialized memory. - const std::string transport_name = channel->transport_name(); + const std::string transport_name = + channel->rtp_dtls_transport()->transport_name(); const bool need_to_delete_rtcp = (channel->rtcp_dtls_transport() != nullptr); switch (channel->media_type()) { diff --git a/pc/peerconnection_media_unittest.cc b/pc/peerconnection_media_unittest.cc index 6dabb6e939..8d1dd7694f 100644 --- a/pc/peerconnection_media_unittest.cc +++ b/pc/peerconnection_media_unittest.cc @@ -959,8 +959,8 @@ void RenameContent(cricket::SessionDescription* desc, // Tests that an answer responds with the same MIDs as the offer. TEST_P(PeerConnectionMediaTest, AnswerHasSameMidsAsOffer) { - const std::string kAudioMid = "notdefault1"; - const std::string kVideoMid = "notdefault2"; + const std::string kAudioMid = "not default1"; + const std::string kVideoMid = "not default2"; auto caller = CreatePeerConnectionWithAudioVideo(); auto callee = CreatePeerConnectionWithAudioVideo(); @@ -980,8 +980,8 @@ TEST_P(PeerConnectionMediaTest, AnswerHasSameMidsAsOffer) { // Test that if the callee creates a re-offer, the MIDs are the same as the // original offer. TEST_P(PeerConnectionMediaTest, ReOfferHasSameMidsAsFirstOffer) { - const std::string kAudioMid = "notdefault1"; - const std::string kVideoMid = "notdefault2"; + const std::string kAudioMid = "not default1"; + const std::string kVideoMid = "not default2"; auto caller = CreatePeerConnectionWithAudioVideo(); auto callee = CreatePeerConnectionWithAudioVideo(); diff --git a/pc/rtptransport.cc b/pc/rtptransport.cc index f59be3bc64..26f7e3e4c9 100644 --- a/pc/rtptransport.cc +++ b/pc/rtptransport.cc @@ -10,10 +10,7 @@ #include "pc/rtptransport.h" -#include - #include "media/base/rtputils.h" -#include "modules/rtp_rtcp/source/rtp_packet_received.h" #include "p2p/base/p2pconstants.h" #include "p2p/base/packettransportinterface.h" #include "rtc_base/checks.h" @@ -47,7 +44,7 @@ void RtpTransport::SetRtpPacketTransport( new_packet_transport->SignalReadPacket.connect(this, &RtpTransport::OnReadPacket); new_packet_transport->SignalNetworkRouteChanged.connect( - this, &RtpTransport::OnNetworkRouteChanged); + this, &RtpTransport::OnNetworkRouteChange); new_packet_transport->SignalWritableState.connect( this, &RtpTransport::OnWritableState); new_packet_transport->SignalSentPacket.connect(this, @@ -83,7 +80,7 @@ void RtpTransport::SetRtcpPacketTransport( new_packet_transport->SignalReadPacket.connect(this, &RtpTransport::OnReadPacket); new_packet_transport->SignalNetworkRouteChanged.connect( - this, &RtpTransport::OnNetworkRouteChanged); + this, &RtpTransport::OnNetworkRouteChange); new_packet_transport->SignalWritableState.connect( this, &RtpTransport::OnWritableState); new_packet_transport->SignalSentPacket.connect(this, @@ -137,27 +134,16 @@ bool RtpTransport::SendPacket(bool rtcp, return true; } -void RtpTransport::UpdateRtpHeaderExtensionMap( - const cricket::RtpHeaderExtensions& header_extensions) { - header_extension_map_ = RtpHeaderExtensionMap(header_extensions); +bool RtpTransport::HandlesPacket(const uint8_t* data, size_t len) { + return bundle_filter_.DemuxPacket(data, len); } -bool RtpTransport::RegisterRtpDemuxerSink(const RtpDemuxerCriteria& criteria, - RtpPacketSinkInterface* sink) { - rtp_demuxer_.RemoveSink(sink); - if (!rtp_demuxer_.AddSink(criteria, sink)) { - RTC_LOG(LS_ERROR) << "Failed to register the sink for RTP demuxer."; - return false; - } - return true; +bool RtpTransport::HandlesPayloadType(int payload_type) const { + return bundle_filter_.FindPayloadType(payload_type); } -bool RtpTransport::UnregisterRtpDemuxerSink(RtpPacketSinkInterface* sink) { - if (!rtp_demuxer_.RemoveSink(sink)) { - RTC_LOG(LS_ERROR) << "Failed to unregister the sink for RTP demuxer."; - return false; - } - return true; +void RtpTransport::AddHandledPayloadType(int payload_type) { + bundle_filter_.AddPayloadType(payload_type); } PacketTransportInterface* RtpTransport::GetRtpPacketTransport() const { @@ -194,26 +180,11 @@ RtpTransportParameters RtpTransport::GetParameters() const { return parameters_; } -void RtpTransport::DemuxPacket(rtc::CopyOnWriteBuffer* packet, - const rtc::PacketTime& time) { - webrtc::RtpPacketReceived parsed_packet(&header_extension_map_); - if (!parsed_packet.Parse(std::move(*packet))) { - RTC_LOG(LS_ERROR) - << "Failed to parse the incoming RTP packet before demuxing. Drop it."; - return; - } - - if (time.timestamp != -1) { - parsed_packet.set_arrival_time_ms((time.timestamp + 500) / 1000); - } - rtp_demuxer_.OnRtpPacket(parsed_packet); -} - RtpTransportAdapter* RtpTransport::GetInternal() { return nullptr; } -bool RtpTransport::IsTransportWritable() { +bool RtpTransport::IsRtpTransportWritable() { auto rtcp_packet_transport = rtcp_mux_enabled_ ? nullptr : rtcp_packet_transport_; return rtp_packet_transport_ && rtp_packet_transport_->writable() && @@ -224,7 +195,7 @@ void RtpTransport::OnReadyToSend(rtc::PacketTransportInternal* transport) { SetReadyToSend(transport == rtcp_packet_transport_, true); } -void RtpTransport::OnNetworkRouteChanged( +void RtpTransport::OnNetworkRouteChange( rtc::Optional network_route) { SignalNetworkRouteChanged(network_route); } @@ -233,7 +204,7 @@ void RtpTransport::OnWritableState( rtc::PacketTransportInternal* packet_transport) { RTC_DCHECK(packet_transport == rtp_packet_transport_ || packet_transport == rtcp_packet_transport_); - SignalWritableState(IsTransportWritable()); + SignalWritableState(IsRtpTransportWritable()); } void RtpTransport::OnSentPacket(rtc::PacketTransportInternal* packet_transport, @@ -243,44 +214,6 @@ void RtpTransport::OnSentPacket(rtc::PacketTransportInternal* packet_transport, SignalSentPacket(sent_packet); } -void RtpTransport::OnRtpPacketReceived(rtc::CopyOnWriteBuffer* packet, - const rtc::PacketTime& packet_time) { - DemuxPacket(packet, packet_time); -} - -void RtpTransport::OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, - const rtc::PacketTime& packet_time) { - SignalRtcpPacketReceived(packet, packet_time); -} - -void RtpTransport::OnReadPacket(rtc::PacketTransportInternal* transport, - const char* data, - size_t len, - const rtc::PacketTime& packet_time, - int flags) { - TRACE_EVENT0("webrtc", "RtpTransport::OnReadPacket"); - - // When using RTCP multiplexing we might get RTCP packets on the RTP - // transport. We check the RTP payload type to determine if it is RTCP. - bool rtcp = transport == rtcp_packet_transport() || - cricket::IsRtcp(data, static_cast(len)); - rtc::CopyOnWriteBuffer packet(data, len); - - // Protect ourselves against crazy data. - if (!cricket::IsValidRtpRtcpPacketSize(rtcp, packet.size())) { - RTC_LOG(LS_ERROR) << "Dropping incoming " - << cricket::RtpRtcpStringLiteral(rtcp) - << " packet: wrong size=" << packet.size(); - return; - } - - if (rtcp) { - OnRtcpPacketReceived(&packet, packet_time); - } else { - OnRtpPacketReceived(&packet, packet_time); - } -} - void RtpTransport::SetReadyToSend(bool rtcp, bool ready) { if (rtcp) { rtcp_ready_to_send_ = ready; @@ -300,4 +233,51 @@ void RtpTransport::MaybeSignalReadyToSend() { } } +// Check the RTP payload type. If 63 < payload type < 96, it's RTCP. +// For additional details, see http://tools.ietf.org/html/rfc5761. +bool IsRtcp(const char* data, int len) { + if (len < 2) { + return false; + } + char pt = data[1] & 0x7F; + return (63 < pt) && (pt < 96); +} + +void RtpTransport::OnReadPacket(rtc::PacketTransportInternal* transport, + const char* data, + size_t len, + const rtc::PacketTime& packet_time, + int flags) { + TRACE_EVENT0("webrtc", "RtpTransport::OnReadPacket"); + + // When using RTCP multiplexing we might get RTCP packets on the RTP + // transport. We check the RTP payload type to determine if it is RTCP. + bool rtcp = transport == rtcp_packet_transport() || + IsRtcp(data, static_cast(len)); + rtc::CopyOnWriteBuffer packet(data, len); + + if (!WantsPacket(rtcp, &packet)) { + return; + } + // This mutates |packet| if it is protected. + SignalPacketReceived(rtcp, &packet, packet_time); +} + +bool RtpTransport::WantsPacket(bool rtcp, + const rtc::CopyOnWriteBuffer* packet) { + // Protect ourselves against crazy data. + if (!packet || !cricket::IsValidRtpRtcpPacketSize(rtcp, packet->size())) { + RTC_LOG(LS_ERROR) << "Dropping incoming " + << cricket::RtpRtcpStringLiteral(rtcp) + << " packet: wrong size=" << packet->size(); + return false; + } + if (rtcp) { + // Permit all (seemingly valid) RTCP packets. + return true; + } + // Check whether we handle this payload. + return HandlesPacket(packet->data(), packet->size()); +} + } // namespace webrtc diff --git a/pc/rtptransport.h b/pc/rtptransport.h index a857fea3dd..637d447454 100644 --- a/pc/rtptransport.h +++ b/pc/rtptransport.h @@ -13,8 +13,7 @@ #include -#include "call/rtp_demuxer.h" -#include "modules/rtp_rtcp/include/rtp_header_extension_map.h" +#include "pc/bundlefilter.h" #include "pc/rtptransportinternal.h" #include "rtc_base/sigslot.h" @@ -67,13 +66,9 @@ class RtpTransport : public RtpTransportInternal { const rtc::PacketOptions& options, int flags) override; - void UpdateRtpHeaderExtensionMap( - const cricket::RtpHeaderExtensions& header_extensions) override; + bool HandlesPayloadType(int payload_type) const override; - bool RegisterRtpDemuxerSink(const RtpDemuxerCriteria& criteria, - RtpPacketSinkInterface* sink) override; - - bool UnregisterRtpDemuxerSink(RtpPacketSinkInterface* sink) override; + void AddHandledPayloadType(int payload_type) override; void SetMetricsObserver( rtc::scoped_refptr metrics_observer) override {} @@ -82,35 +77,15 @@ class RtpTransport : public RtpTransportInternal { // TODO(zstein): Remove this when we remove RtpTransportAdapter. RtpTransportAdapter* GetInternal() override; - // These methods will be used in the subclasses. - void DemuxPacket(rtc::CopyOnWriteBuffer* packet, const rtc::PacketTime& time); - - bool SendPacket(bool rtcp, - rtc::CopyOnWriteBuffer* packet, - const rtc::PacketOptions& options, - int flags); - - bool IsTransportWritable(); - - // Overridden by SrtpTransport. - virtual void OnNetworkRouteChanged( - rtc::Optional network_route); - virtual void OnRtpPacketReceived(rtc::CopyOnWriteBuffer* packet, - const rtc::PacketTime& packet_time); - virtual void OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, - const rtc::PacketTime& packet_time); - // Overridden by DtlsSrtpTransport. - virtual void OnWritableState(rtc::PacketTransportInternal* packet_transport); - private: + bool IsRtpTransportWritable(); + bool HandlesPacket(const uint8_t* data, size_t len); + void OnReadyToSend(rtc::PacketTransportInternal* transport); + void OnNetworkRouteChange(rtc::Optional network_route); + void OnWritableState(rtc::PacketTransportInternal* packet_transport); void OnSentPacket(rtc::PacketTransportInternal* packet_transport, const rtc::SentPacket& sent_packet); - void OnReadPacket(rtc::PacketTransportInternal* transport, - const char* data, - size_t len, - const rtc::PacketTime& packet_time, - int flags); // Updates "ready to send" for an individual channel and fires // SignalReadyToSend. @@ -118,17 +93,18 @@ class RtpTransport : public RtpTransportInternal { void MaybeSignalReadyToSend(); - // SRTP specific methods. - // TODO(zhihuang): Improve the inheritance model so that the RtpTransport - // doesn't need to implement SRTP specfic methods. - RTCError SetSrtpSendKey(const cricket::CryptoParams& params) override { - RTC_NOTREACHED(); - return RTCError::OK(); - } - RTCError SetSrtpReceiveKey(const cricket::CryptoParams& params) override { - RTC_NOTREACHED(); - return RTCError::OK(); - } + bool SendPacket(bool rtcp, + rtc::CopyOnWriteBuffer* packet, + const rtc::PacketOptions& options, + int flags); + + void OnReadPacket(rtc::PacketTransportInternal* transport, + const char* data, + size_t len, + const rtc::PacketTime& packet_time, + int flags); + + bool WantsPacket(bool rtcp, const rtc::CopyOnWriteBuffer* packet); bool rtcp_mux_enabled_; @@ -140,10 +116,8 @@ class RtpTransport : public RtpTransportInternal { bool rtcp_ready_to_send_ = false; RtpTransportParameters parameters_; - RtpDemuxer rtp_demuxer_; - // Used for identifying the MID for RtpDemuxer. - RtpHeaderExtensionMap header_extension_map_; + cricket::BundleFilter bundle_filter_; }; } // namespace webrtc diff --git a/pc/rtptransport_unittest.cc b/pc/rtptransport_unittest.cc index efc2e8ccf7..3876aa3998 100644 --- a/pc/rtptransport_unittest.cc +++ b/pc/rtptransport_unittest.cc @@ -197,37 +197,49 @@ TEST(RtpTransportTest, SetRtcpTransportWithNetworkRouteChanged) { EXPECT_FALSE(observer.network_route()); } +class SignalCounter : public sigslot::has_slots<> { + public: + explicit SignalCounter(RtpTransport* transport) { + transport->SignalReadyToSend.connect(this, &SignalCounter::OnReadyToSend); + } + int count() const { return count_; } + void OnReadyToSend(bool ready) { ++count_; } + + private: + int count_ = 0; +}; + TEST(RtpTransportTest, ChangingReadyToSendStateOnlySignalsWhenChanged) { RtpTransport transport(kMuxEnabled); - TransportObserver observer(&transport); + SignalCounter observer(&transport); rtc::FakePacketTransport fake_rtp("fake_rtp"); fake_rtp.SetWritable(true); // State changes, so we should signal. transport.SetRtpPacketTransport(&fake_rtp); - EXPECT_EQ(observer.ready_to_send_signal_count(), 1); + EXPECT_EQ(observer.count(), 1); // State does not change, so we should not signal. transport.SetRtpPacketTransport(&fake_rtp); - EXPECT_EQ(observer.ready_to_send_signal_count(), 1); + EXPECT_EQ(observer.count(), 1); // State does not change, so we should not signal. transport.SetRtcpMuxEnabled(true); - EXPECT_EQ(observer.ready_to_send_signal_count(), 1); + EXPECT_EQ(observer.count(), 1); // State changes, so we should signal. transport.SetRtcpMuxEnabled(false); - EXPECT_EQ(observer.ready_to_send_signal_count(), 2); + EXPECT_EQ(observer.count(), 2); } // Test that SignalPacketReceived fires with rtcp=true when a RTCP packet is // received. TEST(RtpTransportTest, SignalDemuxedRtcp) { RtpTransport transport(kMuxDisabled); + SignalPacketReceivedCounter observer(&transport); rtc::FakePacketTransport fake_rtp("fake_rtp"); fake_rtp.SetDestination(&fake_rtp, true); transport.SetRtpPacketTransport(&fake_rtp); - TransportObserver observer(&transport); // An rtcp packet. const char data[] = {0, 73, 0, 0}; @@ -247,15 +259,11 @@ static const int kRtpLen = 12; // handled payload type is received. TEST(RtpTransportTest, SignalHandledRtpPayloadType) { RtpTransport transport(kMuxDisabled); + SignalPacketReceivedCounter observer(&transport); rtc::FakePacketTransport fake_rtp("fake_rtp"); fake_rtp.SetDestination(&fake_rtp, true); - // Disable the encryption to allow raw RTP data. transport.SetRtpPacketTransport(&fake_rtp); - TransportObserver observer(&transport); - RtpDemuxerCriteria demuxer_criteria; - // Add a handled payload type. - demuxer_criteria.payload_types = {0x11}; - transport.RegisterRtpDemuxerSink(demuxer_criteria, &observer); + transport.AddHandledPayloadType(0x11); // An rtp packet. const rtc::PacketOptions options; @@ -264,22 +272,16 @@ TEST(RtpTransportTest, SignalHandledRtpPayloadType) { fake_rtp.SendPacket(rtp_data.data(), kRtpLen, options, flags); EXPECT_EQ(1, observer.rtp_count()); EXPECT_EQ(0, observer.rtcp_count()); - // Remove the sink before destroying the transport. - transport.UnregisterRtpDemuxerSink(&observer); } // Test that SignalPacketReceived does not fire when a RTP packet with an // unhandled payload type is received. TEST(RtpTransportTest, DontSignalUnhandledRtpPayloadType) { RtpTransport transport(kMuxDisabled); + SignalPacketReceivedCounter observer(&transport); rtc::FakePacketTransport fake_rtp("fake_rtp"); fake_rtp.SetDestination(&fake_rtp, true); transport.SetRtpPacketTransport(&fake_rtp); - TransportObserver observer(&transport); - RtpDemuxerCriteria demuxer_criteria; - // Add an unhandled payload type. - demuxer_criteria.payload_types = {0x12}; - transport.RegisterRtpDemuxerSink(demuxer_criteria, &observer); const rtc::PacketOptions options; const int flags = 0; @@ -287,8 +289,6 @@ TEST(RtpTransportTest, DontSignalUnhandledRtpPayloadType) { fake_rtp.SendPacket(rtp_data.data(), kRtpLen, options, flags); EXPECT_EQ(0, observer.rtp_count()); EXPECT_EQ(0, observer.rtcp_count()); - // Remove the sink before destroying the transport. - transport.UnregisterRtpDemuxerSink(&observer); } } // namespace webrtc diff --git a/pc/rtptransportinternal.h b/pc/rtptransportinternal.h index 4b35a57666..0665fd76bc 100644 --- a/pc/rtptransportinternal.h +++ b/pc/rtptransportinternal.h @@ -13,11 +13,9 @@ #include -#include "api/ortc/srtptransportinterface.h" +#include "api/ortc/rtptransportinterface.h" #include "api/umametrics.h" -#include "call/rtp_demuxer.h" #include "p2p/base/icetransportinternal.h" -#include "pc/sessiondescription.h" #include "rtc_base/networkroute.h" #include "rtc_base/sigslot.h" @@ -29,11 +27,11 @@ struct PacketTime; namespace webrtc { -// This represents the internal interface beneath SrtpTransportInterface; +// This represents the internal interface beneath RtpTransportInterface; // it is not accessible to API consumers but is accessible to internal classes // in order to send and receive RTP and RTCP packets belonging to a single RTP // session. Additional convenience and configuration methods are also provided. -class RtpTransportInternal : public SrtpTransportInterface, +class RtpTransportInternal : public RtpTransportInterface, public sigslot::has_slots<> { public: virtual void SetRtcpMuxEnabled(bool enable) = 0; @@ -54,11 +52,11 @@ class RtpTransportInternal : public SrtpTransportInterface, // than just "writable"; it means the last send didn't return ENOTCONN. sigslot::signal1 SignalReadyToSend; - // Called whenever an RTCP packet is received. There is no equivalent signal - // for RTP packets because they would be forwarded to the BaseChannel through - // the RtpDemuxer callback. - sigslot::signal2 - SignalRtcpPacketReceived; + // TODO(zstein): Consider having two signals - RtpPacketReceived and + // RtcpPacketReceived. + // The first argument is true for RTCP packets and false for RTP packets. + sigslot::signal3 + SignalPacketReceived; // Called whenever the network route of the P2P layer transport changes. // The argument is an optional network route. @@ -82,22 +80,12 @@ class RtpTransportInternal : public SrtpTransportInterface, const rtc::PacketOptions& options, int flags) = 0; - // This method updates the RTP header extension map so that the RTP transport - // can parse the received packets and identify the MID. This is called by the - // BaseChannel when setting the content description. - // - // Note: This doesn't take the BUNDLE case in account meaning the RTP header - // extension maps are not merged when BUNDLE is enabled. This is fine because - // the ID for MID should be consistent among all the RTP transports. - virtual void UpdateRtpHeaderExtensionMap( - const cricket::RtpHeaderExtensions& header_extensions) = 0; + virtual bool HandlesPayloadType(int payload_type) const = 0; + + virtual void AddHandledPayloadType(int payload_type) = 0; virtual void SetMetricsObserver( rtc::scoped_refptr metrics_observer) = 0; - virtual bool RegisterRtpDemuxerSink(const RtpDemuxerCriteria& criteria, - RtpPacketSinkInterface* sink) = 0; - - virtual bool UnregisterRtpDemuxerSink(RtpPacketSinkInterface* sink) = 0; }; } // namespace webrtc diff --git a/pc/rtptransportinternaladapter.h b/pc/rtptransportinternaladapter.h index f46b3a57b4..6a2d7e2aa7 100644 --- a/pc/rtptransportinternaladapter.h +++ b/pc/rtptransportinternaladapter.h @@ -67,22 +67,12 @@ class RtpTransportInternalAdapter : public RtpTransportInternal { return transport_->SendRtcpPacket(packet, options, flags); } - void UpdateRtpHeaderExtensionMap( - const cricket::RtpHeaderExtensions& header_extensions) override { - transport_->UpdateRtpHeaderExtensionMap(header_extensions); + bool HandlesPayloadType(int payload_type) const override { + return transport_->HandlesPayloadType(payload_type); } - void RegisterRtpDemuxerSink(const RtpDemuxerCriteria& criteria, - RtpPacketSinkInterface* sink) override { - transport_->RegisterRtpDemuxerSink(criteria, sink); - } - - void UnregisterRtpDemuxerSink(RtpPacketSinkInterface* sink) override { - transport_->UnregisterRtpDemuxerSink(sink); - } - - void SetEncryptionDisabled(bool encryption_disabled) override { - transport_->SetEncryptionDisabled(encryption_disabled); + void AddHandledPayloadType(int payload_type) override { + return transport_->AddHandledPayloadType(payload_type); } // RtpTransportInterface overrides. diff --git a/pc/rtptransporttestutil.h b/pc/rtptransporttestutil.h index 9489fa2bfe..c2bdaad23e 100644 --- a/pc/rtptransporttestutil.h +++ b/pc/rtptransporttestutil.h @@ -11,67 +11,32 @@ #ifndef PC_RTPTRANSPORTTESTUTIL_H_ #define PC_RTPTRANSPORTTESTUTIL_H_ -#include "call/rtp_packet_sink_interface.h" -#include "modules/rtp_rtcp/source/rtp_packet_received.h" #include "pc/rtptransportinternal.h" #include "rtc_base/sigslot.h" namespace webrtc { -// Used to handle the signals when the RtpTransport receives an RTP/RTCP packet. -// Used in Rtp/Srtp/DtlsTransport unit tests. -class TransportObserver : public RtpPacketSinkInterface, - public sigslot::has_slots<> { +class SignalPacketReceivedCounter : public sigslot::has_slots<> { public: - TransportObserver() {} - - explicit TransportObserver(RtpTransportInternal* rtp_transport) { - rtp_transport->SignalRtcpPacketReceived.connect( - this, &TransportObserver::OnRtcpPacketReceived); - rtp_transport->SignalReadyToSend.connect(this, - &TransportObserver::OnReadyToSend); + explicit SignalPacketReceivedCounter(RtpTransportInternal* transport) { + transport->SignalPacketReceived.connect( + this, &SignalPacketReceivedCounter::OnPacketReceived); } - - // RtpPacketInterface override. - void OnRtpPacket(const RtpPacketReceived& packet) override { - rtp_count_++; - last_recv_rtp_packet_ = packet.Buffer(); - } - - void OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, - const rtc::PacketTime& packet_time) { - RTC_LOG(LS_INFO) << "Received an RTCP packet."; - rtcp_count_++; - last_recv_rtcp_packet_ = *packet; - } - - int rtp_count() const { return rtp_count_; } int rtcp_count() const { return rtcp_count_; } - - rtc::CopyOnWriteBuffer last_recv_rtp_packet() { - return last_recv_rtp_packet_; - } - - rtc::CopyOnWriteBuffer last_recv_rtcp_packet() { - return last_recv_rtcp_packet_; - } - - void OnReadyToSend(bool ready) { - ready_to_send_signal_count_++; - ready_to_send_ = ready; - } - - bool ready_to_send() { return ready_to_send_; } - - int ready_to_send_signal_count() { return ready_to_send_signal_count_; } + int rtp_count() const { return rtp_count_; } private: - bool ready_to_send_ = false; - int rtp_count_ = 0; + void OnPacketReceived(bool rtcp, + rtc::CopyOnWriteBuffer*, + const rtc::PacketTime&) { + if (rtcp) { + ++rtcp_count_; + } else { + ++rtp_count_; + } + } int rtcp_count_ = 0; - int ready_to_send_signal_count_ = 0; - rtc::CopyOnWriteBuffer last_recv_rtp_packet_; - rtc::CopyOnWriteBuffer last_recv_rtcp_packet_; + int rtp_count_ = 0; }; } // namespace webrtc diff --git a/pc/srtptransport.cc b/pc/srtptransport.cc index 2409aee445..5eff3c932b 100644 --- a/pc/srtptransport.cc +++ b/pc/srtptransport.cc @@ -25,133 +25,155 @@ namespace webrtc { SrtpTransport::SrtpTransport(bool rtcp_mux_enabled) - : RtpTransport(rtcp_mux_enabled) {} + : RtpTransportInternalAdapter(new RtpTransport(rtcp_mux_enabled)) { + // Own the raw pointer |transport| from the base class. + rtp_transport_.reset(transport_); + RTC_DCHECK(rtp_transport_); + ConnectToRtpTransport(); +} + +SrtpTransport::SrtpTransport( + std::unique_ptr rtp_transport) + : RtpTransportInternalAdapter(rtp_transport.get()), + rtp_transport_(std::move(rtp_transport)) { + RTC_DCHECK(rtp_transport_); + ConnectToRtpTransport(); +} + +void SrtpTransport::ConnectToRtpTransport() { + rtp_transport_->SignalPacketReceived.connect( + this, &SrtpTransport::OnPacketReceived); + rtp_transport_->SignalReadyToSend.connect(this, + &SrtpTransport::OnReadyToSend); + rtp_transport_->SignalNetworkRouteChanged.connect( + this, &SrtpTransport::OnNetworkRouteChanged); + rtp_transport_->SignalWritableState.connect(this, + &SrtpTransport::OnWritableState); + rtp_transport_->SignalSentPacket.connect(this, &SrtpTransport::OnSentPacket); +} bool SrtpTransport::SendRtpPacket(rtc::CopyOnWriteBuffer* packet, const rtc::PacketOptions& options, int flags) { + return SendPacket(false, packet, options, flags); +} + +bool SrtpTransport::SendRtcpPacket(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketOptions& options, + int flags) { + return SendPacket(true, packet, options, flags); +} + +bool SrtpTransport::SendPacket(bool rtcp, + rtc::CopyOnWriteBuffer* packet, + const rtc::PacketOptions& options, + int flags) { if (!IsActive()) { RTC_LOG(LS_ERROR) << "Failed to send the packet because SRTP transport is inactive."; return false; } + rtc::PacketOptions updated_options = options; + rtc::CopyOnWriteBuffer cp = *packet; TRACE_EVENT0("webrtc", "SRTP Encode"); bool res; uint8_t* data = packet->data(); int len = static_cast(packet->size()); + if (!rtcp) { // If ENABLE_EXTERNAL_AUTH flag is on then packet authentication is not done // inside libsrtp for a RTP packet. A external HMAC module will be writing // a fake HMAC value. This is ONLY done for a RTP packet. // Socket layer will update rtp sendtime extension header if present in // packet with current time before updating the HMAC. #if !defined(ENABLE_EXTERNAL_AUTH) - res = ProtectRtp(data, len, static_cast(packet->capacity()), &len); -#else - if (!IsExternalAuthActive()) { res = ProtectRtp(data, len, static_cast(packet->capacity()), &len); - } else { - updated_options.packet_time_params.rtp_sendtime_extension_id = - rtp_abs_sendtime_extn_id_; - res = ProtectRtp(data, len, static_cast(packet->capacity()), &len, - &updated_options.packet_time_params.srtp_packet_index); - // If protection succeeds, let's get auth params from srtp. - if (res) { - uint8_t* auth_key = NULL; - int key_len; - res = GetRtpAuthParams( - &auth_key, &key_len, - &updated_options.packet_time_params.srtp_auth_tag_len); +#else + if (!IsExternalAuthActive()) { + res = ProtectRtp(data, len, static_cast(packet->capacity()), &len); + } else { + updated_options.packet_time_params.rtp_sendtime_extension_id = + rtp_abs_sendtime_extn_id_; + res = ProtectRtp(data, len, static_cast(packet->capacity()), &len, + &updated_options.packet_time_params.srtp_packet_index); + // If protection succeeds, let's get auth params from srtp. if (res) { - updated_options.packet_time_params.srtp_auth_key.resize(key_len); - updated_options.packet_time_params.srtp_auth_key.assign( - auth_key, auth_key + key_len); + uint8_t* auth_key = NULL; + int key_len; + res = GetRtpAuthParams( + &auth_key, &key_len, + &updated_options.packet_time_params.srtp_auth_tag_len); + if (res) { + updated_options.packet_time_params.srtp_auth_key.resize(key_len); + updated_options.packet_time_params.srtp_auth_key.assign( + auth_key, auth_key + key_len); + } } } - } #endif - if (!res) { - int seq_num = -1; - uint32_t ssrc = 0; - cricket::GetRtpSeqNum(data, len, &seq_num); - cricket::GetRtpSsrc(data, len, &ssrc); - RTC_LOG(LS_ERROR) << "Failed to protect RTP packet: size=" << len - << ", seqnum=" << seq_num << ", SSRC=" << ssrc; - return false; + if (!res) { + int seq_num = -1; + uint32_t ssrc = 0; + cricket::GetRtpSeqNum(data, len, &seq_num); + cricket::GetRtpSsrc(data, len, &ssrc); + RTC_LOG(LS_ERROR) << "Failed to protect RTP packet: size=" << len + << ", seqnum=" << seq_num << ", SSRC=" << ssrc; + return false; + } + } else { + res = ProtectRtcp(data, len, static_cast(packet->capacity()), &len); + if (!res) { + int type = -1; + cricket::GetRtcpType(data, len, &type); + RTC_LOG(LS_ERROR) << "Failed to protect RTCP packet: size=" << len + << ", type=" << type; + return false; + } } // Update the length of the packet now that we've added the auth tag. packet->SetSize(len); - return SendPacket(/*rtcp=*/false, packet, updated_options, flags); + return rtcp ? rtp_transport_->SendRtcpPacket(packet, updated_options, flags) + : rtp_transport_->SendRtpPacket(packet, updated_options, flags); } -bool SrtpTransport::SendRtcpPacket(rtc::CopyOnWriteBuffer* packet, - const rtc::PacketOptions& options, - int flags) { - if (!IsActive()) { - RTC_LOG(LS_ERROR) - << "Failed to send the packet because SRTP transport is inactive."; - return false; - } - - TRACE_EVENT0("webrtc", "SRTP Encode"); - uint8_t* data = packet->data(); - int len = static_cast(packet->size()); - if (!ProtectRtcp(data, len, static_cast(packet->capacity()), &len)) { - int type = -1; - cricket::GetRtcpType(data, len, &type); - RTC_LOG(LS_ERROR) << "Failed to protect RTCP packet: size=" << len - << ", type=" << type; - return false; - } - // Update the length of the packet now that we've added the auth tag. - packet->SetSize(len); - - return SendPacket(/*rtcp=*/true, packet, options, flags); -} - -void SrtpTransport::OnRtpPacketReceived(rtc::CopyOnWriteBuffer* packet, - const rtc::PacketTime& packet_time) { +void SrtpTransport::OnPacketReceived(bool rtcp, + rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& packet_time) { if (!IsActive()) { RTC_LOG(LS_WARNING) - << "Inactive SRTP transport received an RTP packet. Drop it."; + << "Inactive SRTP transport received a packet. Drop it."; return; } + TRACE_EVENT0("webrtc", "SRTP Decode"); char* data = packet->data(); int len = static_cast(packet->size()); - if (!UnprotectRtp(data, len, &len)) { - int seq_num = -1; - uint32_t ssrc = 0; - cricket::GetRtpSeqNum(data, len, &seq_num); - cricket::GetRtpSsrc(data, len, &ssrc); - RTC_LOG(LS_ERROR) << "Failed to unprotect RTP packet: size=" << len - << ", seqnum=" << seq_num << ", SSRC=" << ssrc; - return; + bool res; + if (!rtcp) { + res = UnprotectRtp(data, len, &len); + if (!res) { + int seq_num = -1; + uint32_t ssrc = 0; + cricket::GetRtpSeqNum(data, len, &seq_num); + cricket::GetRtpSsrc(data, len, &ssrc); + RTC_LOG(LS_ERROR) << "Failed to unprotect RTP packet: size=" << len + << ", seqnum=" << seq_num << ", SSRC=" << ssrc; + return; + } + } else { + res = UnprotectRtcp(data, len, &len); + if (!res) { + int type = -1; + cricket::GetRtcpType(data, len, &type); + RTC_LOG(LS_ERROR) << "Failed to unprotect RTCP packet: size=" << len + << ", type=" << type; + return; + } } - packet->SetSize(len); - DemuxPacket(packet, packet_time); -} -void SrtpTransport::OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, - const rtc::PacketTime& packet_time) { - if (!IsActive()) { - RTC_LOG(LS_WARNING) - << "Inactive SRTP transport received an RTCP packet. Drop it."; - return; - } - TRACE_EVENT0("webrtc", "SRTP Decode"); - char* data = packet->data(); - int len = static_cast(packet->size()); - if (!UnprotectRtcp(data, len, &len)) { - int type = -1; - cricket::GetRtcpType(data, len, &type); - RTC_LOG(LS_ERROR) << "Failed to unprotect RTCP packet: size=" << len - << ", type=" << type; - return; - } packet->SetSize(len); - SignalRtcpPacketReceived(packet, packet_time); + SignalPacketReceived(rtcp, packet, packet_time); } void SrtpTransport::OnNetworkRouteChanged( @@ -392,6 +414,7 @@ void SrtpTransport::SetMetricsObserver( if (recv_rtcp_session_) { recv_rtcp_session_->SetMetricsObserver(metrics_observer_); } + rtp_transport_->SetMetricsObserver(metrics_observer); } } // namespace webrtc diff --git a/pc/srtptransport.h b/pc/srtptransport.h index 24c26f7f11..818aabb2a3 100644 --- a/pc/srtptransport.h +++ b/pc/srtptransport.h @@ -17,20 +17,21 @@ #include #include "p2p/base/icetransportinternal.h" -#include "pc/rtptransport.h" +#include "pc/rtptransportinternaladapter.h" #include "pc/srtpfilter.h" #include "pc/srtpsession.h" #include "rtc_base/checks.h" namespace webrtc { -// This subclass of the RtpTransport is used for SRTP which is reponsible for -// protecting/unprotecting the packets. It provides interfaces to set the crypto -// parameters for the SrtpSession underneath. -class SrtpTransport : public RtpTransport { +// This class will eventually be a wrapper around RtpTransportInternal +// that protects and unprotects sent and received RTP packets. +class SrtpTransport : public RtpTransportInternalAdapter { public: explicit SrtpTransport(bool rtcp_mux_enabled); + explicit SrtpTransport(std::unique_ptr rtp_transport); + bool SendRtpPacket(rtc::CopyOnWriteBuffer* packet, const rtc::PacketOptions& options, int flags) override; @@ -39,16 +40,6 @@ class SrtpTransport : public RtpTransport { const rtc::PacketOptions& options, int flags) override; - // SrtpTransportInterface override. - // TODO(zhihuang): Implement these methods and replace the RtpTransportAdapter - // object. - RTCError SetSrtpSendKey(const cricket::CryptoParams& params) override { - return RTCError::OK(); - } - RTCError SetSrtpReceiveKey(const cricket::CryptoParams& params) override { - return RTCError::OK(); - } - // The transport becomes active if the send_session_ and recv_session_ are // created. bool IsActive() const; @@ -114,12 +105,22 @@ class SrtpTransport : public RtpTransport { void ConnectToRtpTransport(); void CreateSrtpSessions(); - void OnRtpPacketReceived(rtc::CopyOnWriteBuffer* packet, - const rtc::PacketTime& packet_time) override; - void OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, - const rtc::PacketTime& packet_time) override; - void OnNetworkRouteChanged( - rtc::Optional network_route) override; + bool SendPacket(bool rtcp, + rtc::CopyOnWriteBuffer* packet, + const rtc::PacketOptions& options, + int flags); + + void OnPacketReceived(bool rtcp, + rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& packet_time); + void OnReadyToSend(bool ready) { SignalReadyToSend(ready); } + void OnNetworkRouteChanged(rtc::Optional network_route); + + void OnWritableState(bool writable) { SignalWritableState(writable); } + + void OnSentPacket(const rtc::SentPacket& sent_packet) { + SignalSentPacket(sent_packet); + } bool ProtectRtp(void* data, int in_len, int max_len, int* out_len); @@ -138,6 +139,7 @@ class SrtpTransport : public RtpTransport { bool UnprotectRtcp(void* data, int in_len, int* out_len); const std::string content_name_; + std::unique_ptr rtp_transport_; std::unique_ptr send_session_; std::unique_ptr recv_session_; diff --git a/pc/srtptransport_unittest.cc b/pc/srtptransport_unittest.cc index 9587e0ff50..e0b2302e6f 100644 --- a/pc/srtptransport_unittest.cc +++ b/pc/srtptransport_unittest.cc @@ -42,6 +42,8 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { protected: SrtpTransportTest() { bool rtcp_mux_enabled = true; + auto rtp_transport1 = rtc::MakeUnique(rtcp_mux_enabled); + auto rtp_transport2 = rtc::MakeUnique(rtcp_mux_enabled); rtp_packet_transport1_ = rtc::MakeUnique("fake_packet_transport1"); @@ -52,32 +54,38 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { rtp_packet_transport1_->SetDestination(rtp_packet_transport2_.get(), asymmetric); - srtp_transport1_ = rtc::MakeUnique(rtcp_mux_enabled); - srtp_transport2_ = rtc::MakeUnique(rtcp_mux_enabled); + rtp_transport1->SetRtpPacketTransport(rtp_packet_transport1_.get()); + rtp_transport2->SetRtpPacketTransport(rtp_packet_transport2_.get()); - srtp_transport1_->SetRtpPacketTransport(rtp_packet_transport1_.get()); - srtp_transport2_->SetRtpPacketTransport(rtp_packet_transport2_.get()); + // Add payload type for RTP packet and RTCP packet. + rtp_transport1->AddHandledPayloadType(0x00); + rtp_transport2->AddHandledPayloadType(0x00); + rtp_transport1->AddHandledPayloadType(0xc9); + rtp_transport2->AddHandledPayloadType(0xc9); - srtp_transport1_->SignalRtcpPacketReceived.connect( - &rtp_sink1_, &TransportObserver::OnRtcpPacketReceived); - srtp_transport2_->SignalRtcpPacketReceived.connect( - &rtp_sink2_, &TransportObserver::OnRtcpPacketReceived); + srtp_transport1_ = + rtc::MakeUnique(std::move(rtp_transport1)); + srtp_transport2_ = + rtc::MakeUnique(std::move(rtp_transport2)); - RtpDemuxerCriteria demuxer_criteria; - // 0x00 is the payload type used in kPcmuFrame. - demuxer_criteria.payload_types = {0x00}; - - srtp_transport1_->RegisterRtpDemuxerSink(demuxer_criteria, &rtp_sink1_); - srtp_transport2_->RegisterRtpDemuxerSink(demuxer_criteria, &rtp_sink2_); + srtp_transport1_->SignalPacketReceived.connect( + this, &SrtpTransportTest::OnPacketReceived1); + srtp_transport2_->SignalPacketReceived.connect( + this, &SrtpTransportTest::OnPacketReceived2); } - ~SrtpTransportTest() { - if (srtp_transport1_) { - srtp_transport1_->UnregisterRtpDemuxerSink(&rtp_sink1_); - } - if (srtp_transport2_) { - srtp_transport2_->UnregisterRtpDemuxerSink(&rtp_sink2_); - } + void OnPacketReceived1(bool rtcp, + rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& packet_time) { + RTC_LOG(LS_INFO) << "SrtpTransport1 Received a packet."; + last_recv_packet1_ = *packet; + } + + void OnPacketReceived2(bool rtcp, + rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& packet_time) { + RTC_LOG(LS_INFO) << "SrtpTransport2 Received a packet."; + last_recv_packet2_ = *packet; } // With external auth enabled, SRTP doesn't write the auth tag and @@ -134,9 +142,9 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { if (srtp_transport1_->IsExternalAuthActive()) { TestRtpAuthParams(srtp_transport1_.get(), cipher_suite_name); } else { - ASSERT_TRUE(rtp_sink2_.last_recv_rtp_packet().data()); - EXPECT_EQ(0, memcmp(rtp_sink2_.last_recv_rtp_packet().data(), - original_rtp_data, rtp_len)); + ASSERT_TRUE(last_recv_packet2_.data()); + EXPECT_EQ(0, + memcmp(last_recv_packet2_.data(), original_rtp_data, rtp_len)); // Get the encrypted packet from underneath packet transport and verify // the data is actually encrypted. auto fake_rtp_packet_transport = static_cast( @@ -151,9 +159,9 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { if (srtp_transport2_->IsExternalAuthActive()) { TestRtpAuthParams(srtp_transport2_.get(), cipher_suite_name); } else { - ASSERT_TRUE(rtp_sink1_.last_recv_rtp_packet().data()); - EXPECT_EQ(0, memcmp(rtp_sink1_.last_recv_rtp_packet().data(), - original_rtp_data, rtp_len)); + ASSERT_TRUE(last_recv_packet1_.data()); + EXPECT_EQ(0, + memcmp(last_recv_packet1_.data(), original_rtp_data, rtp_len)); auto fake_rtp_packet_transport = static_cast( srtp_transport2_->rtp_packet_transport()); EXPECT_NE(0, memcmp(fake_rtp_packet_transport->last_sent_packet()->data(), @@ -162,12 +170,12 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { } void TestSendRecvRtcpPacket(const std::string& cipher_suite_name) { - size_t rtcp_len = sizeof(::kRtcpReport); + size_t rtcp_len = sizeof(kRtcpReport); size_t packet_size = rtcp_len + 4 + rtc::rtcp_auth_tag_len(cipher_suite_name); rtc::Buffer rtcp_packet_buffer(packet_size); char* rtcp_packet_data = rtcp_packet_buffer.data(); - memcpy(rtcp_packet_data, ::kRtcpReport, rtcp_len); + memcpy(rtcp_packet_data, kRtcpReport, rtcp_len); rtc::CopyOnWriteBuffer rtcp_packet1to2(rtcp_packet_data, rtcp_len, packet_size); @@ -179,9 +187,8 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { // that the packet can be successfully received and decrypted. ASSERT_TRUE(srtp_transport1_->SendRtcpPacket(&rtcp_packet1to2, options, cricket::PF_SRTP_BYPASS)); - ASSERT_TRUE(rtp_sink2_.last_recv_rtcp_packet().data()); - EXPECT_EQ(0, memcmp(rtp_sink2_.last_recv_rtcp_packet().data(), - rtcp_packet_data, rtcp_len)); + ASSERT_TRUE(last_recv_packet2_.data()); + EXPECT_EQ(0, memcmp(last_recv_packet2_.data(), rtcp_packet_data, rtcp_len)); // Get the encrypted packet from underneath packet transport and verify the // data is actually encrypted. auto fake_rtp_packet_transport = static_cast( @@ -192,9 +199,8 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { // Do the same thing in the opposite direction; ASSERT_TRUE(srtp_transport2_->SendRtcpPacket(&rtcp_packet2to1, options, cricket::PF_SRTP_BYPASS)); - ASSERT_TRUE(rtp_sink1_.last_recv_rtcp_packet().data()); - EXPECT_EQ(0, memcmp(rtp_sink1_.last_recv_rtcp_packet().data(), - rtcp_packet_data, rtcp_len)); + ASSERT_TRUE(last_recv_packet1_.data()); + EXPECT_EQ(0, memcmp(last_recv_packet1_.data(), rtcp_packet_data, rtcp_len)); fake_rtp_packet_transport = static_cast( srtp_transport2_->rtp_packet_transport()); EXPECT_NE(0, memcmp(fake_rtp_packet_transport->last_sent_packet()->data(), @@ -261,9 +267,8 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { // that the packet can be successfully received and decrypted. ASSERT_TRUE(srtp_transport1_->SendRtpPacket(&rtp_packet1to2, options, cricket::PF_SRTP_BYPASS)); - ASSERT_TRUE(rtp_sink2_.last_recv_rtp_packet().data()); - EXPECT_EQ(0, memcmp(rtp_sink2_.last_recv_rtp_packet().data(), - original_rtp_data, rtp_len)); + ASSERT_TRUE(last_recv_packet2_.data()); + EXPECT_EQ(0, memcmp(last_recv_packet2_.data(), original_rtp_data, rtp_len)); // Get the encrypted packet from underneath packet transport and verify the // data and header extension are actually encrypted. auto fake_rtp_packet_transport = static_cast( @@ -279,9 +284,8 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { // Do the same thing in the opposite direction; ASSERT_TRUE(srtp_transport2_->SendRtpPacket(&rtp_packet2to1, options, cricket::PF_SRTP_BYPASS)); - ASSERT_TRUE(rtp_sink1_.last_recv_rtp_packet().data()); - EXPECT_EQ(0, memcmp(rtp_sink1_.last_recv_rtp_packet().data(), - original_rtp_data, rtp_len)); + ASSERT_TRUE(last_recv_packet1_.data()); + EXPECT_EQ(0, memcmp(last_recv_packet1_.data(), original_rtp_data, rtp_len)); fake_rtp_packet_transport = static_cast( srtp_transport2_->rtp_packet_transport()); EXPECT_NE(0, memcmp(fake_rtp_packet_transport->last_sent_packet()->data(), @@ -324,9 +328,8 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { std::unique_ptr rtp_packet_transport1_; std::unique_ptr rtp_packet_transport2_; - TransportObserver rtp_sink1_; - TransportObserver rtp_sink2_; - + rtc::CopyOnWriteBuffer last_recv_packet1_; + rtc::CopyOnWriteBuffer last_recv_packet2_; int sequence_number_ = 0; }; diff --git a/pc/transportcontroller.cc b/pc/transportcontroller.cc index ee6e53afd0..4e20981da6 100644 --- a/pc/transportcontroller.cc +++ b/pc/transportcontroller.cc @@ -458,12 +458,16 @@ webrtc::DtlsSrtpTransport* TransportController::CreateDtlsSrtpTransport( return existing_rtp_transport->dtls_srtp_transport; } - auto new_dtls_srtp_transport = - rtc::MakeUnique(rtcp_mux_enabled); + auto new_srtp_transport = + rtc::MakeUnique(rtcp_mux_enabled); + #if defined(ENABLE_EXTERNAL_AUTH) - new_dtls_srtp_transport->EnableExternalAuth(); + new_srtp_transport->EnableExternalAuth(); #endif + auto new_dtls_srtp_transport = + rtc::MakeUnique(std::move(new_srtp_transport)); + auto rtp_dtls_transport = CreateDtlsTransport_n( transport_name, cricket::ICE_CANDIDATE_COMPONENT_RTP); auto rtcp_dtls_transport =