diff --git a/media/base/rtputils.cc b/media/base/rtputils.cc index d0ba1cf72b..7cf2c1ba7e 100644 --- a/media/base/rtputils.cc +++ b/media/base/rtputils.cc @@ -275,6 +275,16 @@ bool IsRtpPacket(const void* data, size_t len) { return (static_cast(data)[0] >> 6) == kRtpVersion; } +// Check the RTP payload type. If 63 < payload type < 96, it's RTCP. +// For additional details, see http://tools.ietf.org/html/rfc5761. +bool IsRtcp(const char* data, int len) { + if (len < 2) { + return false; + } + char pt = data[1] & 0x7F; + return (63 < pt) && (pt < 96); +} + bool IsValidRtpPayloadType(int payload_type) { return payload_type >= 0 && payload_type <= 127; } diff --git a/media/base/rtputils.h b/media/base/rtputils.h index 0b7205cf8f..531a2cfeb1 100644 --- a/media/base/rtputils.h +++ b/media/base/rtputils.h @@ -55,6 +55,7 @@ bool SetRtpHeader(void* data, size_t len, const RtpHeader& header); bool IsRtpPacket(const void* data, size_t len); +bool IsRtcp(const char* data, int len); // True if |payload type| is 0-127. bool IsValidRtpPayloadType(int payload_type); diff --git a/pc/BUILD.gn b/pc/BUILD.gn index 907d199168..83bc88bd48 100644 --- a/pc/BUILD.gn +++ b/pc/BUILD.gn @@ -30,8 +30,6 @@ rtc_static_library("rtc_pc_base") { defines = [] sources = [ "audiomonitor.h", - "bundlefilter.cc", - "bundlefilter.h", "channel.cc", "channel.h", "channelmanager.cc", @@ -80,10 +78,13 @@ rtc_static_library("rtc_pc_base") { "../api:optional", "../api:ortc_api", "../api:video_frame_api", + "../call:rtp_interfaces", + "../call:rtp_receiver", "../common_video:common_video", "../media:rtc_data", "../media:rtc_h264_profile_id", "../media:rtc_media_base", + "../modules/rtp_rtcp:rtp_rtcp_format", "../p2p:rtc_p2p", "../rtc_base:checks", "../rtc_base:rtc_base", @@ -274,7 +275,6 @@ if (rtc_include_tests) { testonly = true sources = [ - "bundlefilter_unittest.cc", "channel_unittest.cc", "channelmanager_unittest.cc", "currentspeakermonitor_unittest.cc", @@ -314,9 +314,11 @@ if (rtc_include_tests) { "../api:array_view", "../api:fakemetricsobserver", "../api:libjingle_peerconnection_api", + "../call:rtp_interfaces", "../logging:rtc_event_log_api", "../media:rtc_media_base", "../media:rtc_media_tests_utils", + "../modules/rtp_rtcp:rtp_rtcp_format", "../p2p:p2p_test_utils", "../p2p:rtc_p2p", "../rtc_base:checks", diff --git a/pc/bundlefilter.cc b/pc/bundlefilter.cc deleted file mode 100644 index 7791da6274..0000000000 --- a/pc/bundlefilter.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright 2004 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#include "pc/bundlefilter.h" - -#include "media/base/rtputils.h" -#include "rtc_base/logging.h" - -namespace cricket { - -BundleFilter::BundleFilter() { -} - -BundleFilter::~BundleFilter() { -} - -bool BundleFilter::DemuxPacket(const uint8_t* data, size_t len) { - // For RTP packets, we check whether the payload type can be found. - if (!IsRtpPacket(data, len)) { - return false; - } - - int payload_type = 0; - if (!GetRtpPayloadType(data, len, &payload_type)) { - return false; - } - return FindPayloadType(payload_type); -} - -void BundleFilter::AddPayloadType(int payload_type) { - payload_types_.insert(payload_type); -} - -bool BundleFilter::FindPayloadType(int pl_type) const { - return payload_types_.find(pl_type) != payload_types_.end(); -} - -void BundleFilter::ClearAllPayloadTypes() { - payload_types_.clear(); -} - -} // namespace cricket diff --git a/pc/bundlefilter.h b/pc/bundlefilter.h deleted file mode 100644 index 7decbba8a4..0000000000 --- a/pc/bundlefilter.h +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright 2004 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef PC_BUNDLEFILTER_H_ -#define PC_BUNDLEFILTER_H_ - -#include - -#include -#include - -#include "media/base/streamparams.h" -#include "rtc_base/basictypes.h" - -namespace cricket { - -// In case of single RTP session and single transport channel, all session -// (or media) channels share a common transport channel. Hence they all get -// SignalReadPacket when packet received on transport channel. This requires -// cricket::BaseChannel to know all the valid sources, else media channel -// will decode invalid packets. -// -// This class determines whether a packet is destined for cricket::BaseChannel. -// This is only to be used for RTP packets as RTCP packets are not filtered. -// For RTP packets, this is decided based on the payload type. -class BundleFilter { - public: - BundleFilter(); - ~BundleFilter(); - - // Determines if a RTP packet belongs to valid cricket::BaseChannel. - bool DemuxPacket(const uint8_t* data, size_t len); - - // Adds the supported payload type. - void AddPayloadType(int payload_type); - - // Public for unittests. - bool FindPayloadType(int pl_type) const; - void ClearAllPayloadTypes(); - - private: - std::set payload_types_; -}; - -} // namespace cricket - -#endif // PC_BUNDLEFILTER_H_ diff --git a/pc/bundlefilter_unittest.cc b/pc/bundlefilter_unittest.cc deleted file mode 100644 index 2b1af5c3a3..0000000000 --- a/pc/bundlefilter_unittest.cc +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright 2004 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#include "pc/bundlefilter.h" -#include "rtc_base/gunit.h" - -using cricket::StreamParams; - -static const int kPayloadType1 = 0x11; -static const int kPayloadType2 = 0x22; -static const int kPayloadType3 = 0x33; - -// SSRC = 0x1111, Payload type = 0x11 -static const unsigned char kRtpPacketPt1Ssrc1[] = { - 0x80, kPayloadType1, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, - 0x11, -}; - -// SSRC = 0x2222, Payload type = 0x22 -static const unsigned char kRtpPacketPt2Ssrc2[] = { - 0x80, 0x80 + kPayloadType2, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x22, 0x22, -}; - -// SSRC = 0x2222, Payload type = 0x33 -static const unsigned char kRtpPacketPt3Ssrc2[] = { - 0x80, kPayloadType3, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x22, - 0x22, -}; - -// An SCTP packet. -static const unsigned char kSctpPacket[] = { - 0x00, 0x01, 0x00, 0x01, - 0xff, 0xff, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x00, - 0x03, 0x00, 0x00, 0x04, - 0x00, 0x00, 0x00, 0x00, -}; - -TEST(BundleFilterTest, RtpPacketTest) { - cricket::BundleFilter bundle_filter; - bundle_filter.AddPayloadType(kPayloadType1); - EXPECT_TRUE(bundle_filter.DemuxPacket(kRtpPacketPt1Ssrc1, - sizeof(kRtpPacketPt1Ssrc1))); - bundle_filter.AddPayloadType(kPayloadType2); - EXPECT_TRUE(bundle_filter.DemuxPacket(kRtpPacketPt2Ssrc2, - sizeof(kRtpPacketPt2Ssrc2))); - - // Payload type 0x33 is not added. - EXPECT_FALSE(bundle_filter.DemuxPacket(kRtpPacketPt3Ssrc2, - sizeof(kRtpPacketPt3Ssrc2))); - // Size is too small. - EXPECT_FALSE(bundle_filter.DemuxPacket(kRtpPacketPt1Ssrc1, 11)); - - bundle_filter.ClearAllPayloadTypes(); - EXPECT_FALSE(bundle_filter.DemuxPacket(kRtpPacketPt1Ssrc1, - sizeof(kRtpPacketPt1Ssrc1))); - EXPECT_FALSE(bundle_filter.DemuxPacket(kRtpPacketPt2Ssrc2, - sizeof(kRtpPacketPt2Ssrc2))); -} - -TEST(BundleFilterTest, InvalidRtpPacket) { - cricket::BundleFilter bundle_filter; - EXPECT_FALSE(bundle_filter.DemuxPacket(kSctpPacket, sizeof(kSctpPacket))); -} diff --git a/pc/channel.cc b/pc/channel.cc index 358cb77e55..53a44b62ce 100644 --- a/pc/channel.cc +++ b/pc/channel.cc @@ -29,6 +29,7 @@ // Adding 'nogncheck' to disable the gn include headers check to support modular // WebRTC build targets. #include "media/engine/webrtcvoiceengine.h" // nogncheck +#include "modules/rtp_rtcp/source/rtp_packet_received.h" #include "p2p/base/packettransportinternal.h" #include "pc/channelmanager.h" #include "pc/rtpmediautils.h" @@ -110,6 +111,7 @@ BaseChannel::BaseChannel(rtc::Thread* worker_thread, srtp_required_(srtp_required), media_channel_(std::move(media_channel)) { RTC_DCHECK_RUN_ON(worker_thread_); + demuxer_criteria_.mid = content_name; rtp_transport_ = unencrypted_rtp_transport_.get(); ConnectToRtpTransport(); RTC_LOG(LS_INFO) << "Created channel for " << content_name; @@ -131,13 +133,11 @@ BaseChannel::~BaseChannel() { void BaseChannel::ConnectToRtpTransport() { RTC_DCHECK(rtp_transport_); + RTC_DCHECK(RegisterRtpDemuxerSink()); rtp_transport_->SignalReadyToSend.connect( this, &BaseChannel::OnTransportReadyToSend); - // TODO(zstein): RtpTransport::SignalPacketReceived will probably be replaced - // with a callback interface later so that the demuxer can select which - // channel to signal. - rtp_transport_->SignalPacketReceived.connect(this, - &BaseChannel::OnPacketReceived); + rtp_transport_->SignalRtcpPacketReceived.connect( + this, &BaseChannel::OnRtcpPacketReceived); rtp_transport_->SignalNetworkRouteChanged.connect( this, &BaseChannel::OnNetworkRouteChanged); rtp_transport_->SignalWritableState.connect(this, @@ -154,8 +154,9 @@ void BaseChannel::ConnectToRtpTransport() { void BaseChannel::DisconnectFromRtpTransport() { RTC_DCHECK(rtp_transport_); + rtp_transport_->UnregisterRtpDemuxerSink(this); rtp_transport_->SignalReadyToSend.disconnect(this); - rtp_transport_->SignalPacketReceived.disconnect(this); + rtp_transport_->SignalRtcpPacketReceived.disconnect(this); rtp_transport_->SignalNetworkRouteChanged.disconnect(this); rtp_transport_->SignalWritableState.disconnect(this); rtp_transport_->SignalSentPacket.disconnect(this); @@ -203,17 +204,26 @@ void BaseChannel::Deinit() { // functions, so need to stop this process in Deinit that is called in // derived classes destructor. network_thread_->Invoke(RTC_FROM_HERE, [&] { - FlushRtcpMessages_n(); - - if (dtls_srtp_transport_) { - dtls_srtp_transport_->SetDtlsTransports(nullptr, nullptr); - } else { - rtp_transport_->SetRtpPacketTransport(nullptr); - rtp_transport_->SetRtcpPacketTransport(nullptr); + if (rtp_transport_) { + FlushRtcpMessages_n(); + if (dtls_srtp_transport_) { + dtls_srtp_transport_->SetDtlsTransports(nullptr, nullptr); + } else { + rtp_transport_->SetRtpPacketTransport(nullptr); + rtp_transport_->SetRtcpPacketTransport(nullptr); + } + DisconnectFromRtpTransport(); } + // Clear pending read packets/messages. network_thread_->Clear(&invoker_); network_thread_->Clear(this); + // Because RTP level transports are accessed from the |network_thread_|, + // it's safer to release them from the |network_thread_| as well. + unencrypted_rtp_transport_.reset(); + sdes_transport_.reset(); + dtls_srtp_transport_.reset(); + rtp_transport_ = nullptr; }); } @@ -224,8 +234,10 @@ void BaseChannel::SetRtpTransport(webrtc::RtpTransportInternal* rtp_transport) { return; }); } - RTC_DCHECK(rtp_transport); + if (rtp_transport == rtp_transport_) { + return; + } if (rtp_transport_) { DisconnectFromRtpTransport(); @@ -581,12 +593,37 @@ bool BaseChannel::SendPacket(bool rtcp, : rtp_transport_->SendRtpPacket(packet, options, PF_SRTP_BYPASS); } -bool BaseChannel::HandlesPayloadType(int packet_type) const { - return rtp_transport_->HandlesPayloadType(packet_type); +void BaseChannel::OnRtpPacket(const webrtc::RtpPacketReceived& parsed_packet) { + // Reconstruct the PacketTime from the |parsed_packet|. + // RtpPacketReceived.arrival_time_ms = (PacketTime + 500) / 1000; + // Note: The |not_before| field is always 0 here. This field is not currently + // used, so it should be fine. + int64_t timestamp = -1; + if (parsed_packet.arrival_time_ms() > 0) { + timestamp = parsed_packet.arrival_time_ms() * 1000; + } + rtc::PacketTime packet_time(timestamp, /*not_before=*/0); + OnPacketReceived(/*rtcp=*/false, parsed_packet.Buffer(), packet_time); +} + +void BaseChannel::UpdateRtpHeaderExtensionMap( + const RtpHeaderExtensions& header_extensions) { + RTC_DCHECK(rtp_transport_); + rtp_transport_->UpdateRtpHeaderExtensionMap(header_extensions); +} + +bool BaseChannel::RegisterRtpDemuxerSink() { + RTC_DCHECK(rtp_transport_); + return rtp_transport_->RegisterRtpDemuxerSink(demuxer_criteria_, this); +} + +void BaseChannel::OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& packet_time) { + OnPacketReceived(/*rtcp=*/true, *packet, packet_time); } void BaseChannel::OnPacketReceived(bool rtcp, - rtc::CopyOnWriteBuffer* packet, + const rtc::CopyOnWriteBuffer& packet, const rtc::PacketTime& packet_time) { if (!has_received_packet_ && !rtcp) { has_received_packet_ = true; @@ -613,7 +650,7 @@ void BaseChannel::OnPacketReceived(bool rtcp, invoker_.AsyncInvoke( RTC_FROM_HERE, worker_thread_, - Bind(&BaseChannel::ProcessPacket, this, rtcp, *packet, packet_time)); + Bind(&BaseChannel::ProcessPacket, this, rtcp, packet, packet_time)); } void BaseChannel::ProcessPacket(bool rtcp, @@ -755,14 +792,17 @@ void BaseChannel::EnableSdes_n() { // DtlsSrtpTransport and SrtpTransport shouldn't be enabled at the same // time. RTC_DCHECK(!dtls_srtp_transport_); - RTC_DCHECK(unencrypted_rtp_transport_); - sdes_transport_ = rtc::MakeUnique( - std::move(unencrypted_rtp_transport_)); + + sdes_transport_ = rtc::MakeUnique(rtcp_mux_required_); #if defined(ENABLE_EXTERNAL_AUTH) sdes_transport_->EnableExternalAuth(); #endif + sdes_transport_->SetRtpPacketTransport( + rtp_transport_->rtp_packet_transport()); + sdes_transport_->SetRtcpPacketTransport( + rtp_transport_->rtcp_packet_transport()); SetRtpTransport(sdes_transport_.get()); - RTC_LOG(LS_INFO) << "Wrapping RtpTransport in SrtpTransport."; + RTC_LOG(LS_INFO) << "SrtpTransport is created for SDES."; } void BaseChannel::EnableDtlsSrtp_n() { @@ -772,15 +812,12 @@ void BaseChannel::EnableDtlsSrtp_n() { // DtlsSrtpTransport and SrtpTransport shouldn't be enabled at the same // time. RTC_DCHECK(!sdes_transport_); - RTC_DCHECK(unencrypted_rtp_transport_); - auto srtp_transport = rtc::MakeUnique( - std::move(unencrypted_rtp_transport_)); -#if defined(ENABLE_EXTERNAL_AUTH) - srtp_transport->EnableExternalAuth(); -#endif dtls_srtp_transport_ = - rtc::MakeUnique(std::move(srtp_transport)); + rtc::MakeUnique(rtcp_mux_required_); +#if defined(ENABLE_EXTERNAL_AUTH) + dtls_srtp_transport_->EnableExternalAuth(); +#endif SetRtpTransport(dtls_srtp_transport_.get()); if (cached_send_extension_ids_) { @@ -796,8 +833,7 @@ void BaseChannel::EnableDtlsSrtp_n() { RTC_DCHECK(rtp_dtls_transport_); dtls_srtp_transport_->SetDtlsTransports(rtp_dtls_transport_, rtcp_dtls_transport_); - - RTC_LOG(LS_INFO) << "Wrapping SrtpTransport in DtlsSrtpTransport."; + RTC_LOG(LS_INFO) << "DtlsSrtpTransport is created for DTLS-SRTP."; } bool BaseChannel::SetSrtp_n(const std::vector& cryptos, @@ -1077,7 +1113,7 @@ void BaseChannel::OnMessage(rtc::Message *pmsg) { } void BaseChannel::AddHandledPayloadType(int payload_type) { - rtp_transport_->AddHandledPayloadType(payload_type); + demuxer_criteria_.payload_types.insert(static_cast(payload_type)); } void BaseChannel::FlushRtcpMessages_n() { @@ -1207,6 +1243,7 @@ bool VoiceChannel::SetLocalContent_w(const MediaContentDescription* content, RtpHeaderExtensions rtp_header_extensions = GetFilteredRtpHeaderExtensions(audio->rtp_header_extensions()); + UpdateRtpHeaderExtensionMap(rtp_header_extensions); if (!SetRtpTransportParameters(content, type, CS_LOCAL, rtp_header_extensions, error_desc)) { @@ -1220,9 +1257,16 @@ bool VoiceChannel::SetLocalContent_w(const MediaContentDescription* content, error_desc); return false; } + for (const AudioCodec& codec : audio->codecs()) { AddHandledPayloadType(codec.id); } + // Need to re-register the sink to update the handled payload. + if (!RegisterRtpDemuxerSink()) { + RTC_LOG(LS_ERROR) << "Failed to set up audio demuxing."; + return false; + } + last_recv_params_ = recv_params; // TODO(pthatcher): Move local streams into AudioSendParameters, and @@ -1257,6 +1301,7 @@ bool VoiceChannel::SetRemoteContent_w(const MediaContentDescription* content, RtpHeaderExtensions rtp_header_extensions = GetFilteredRtpHeaderExtensions(audio->rtp_header_extensions()); + UpdateRtpHeaderExtensionMap(rtp_header_extensions); if (!SetRtpTransportParameters(content, type, CS_REMOTE, rtp_header_extensions, error_desc)) { return false; @@ -1312,7 +1357,6 @@ VideoChannel::~VideoChannel() { TRACE_EVENT0("webrtc", "VideoChannel::~VideoChannel"); // this can't be done in the base class, since it calls a virtual DisableMedia_w(); - Deinit(); } @@ -1350,6 +1394,7 @@ bool VideoChannel::SetLocalContent_w(const MediaContentDescription* content, RtpHeaderExtensions rtp_header_extensions = GetFilteredRtpHeaderExtensions(video->rtp_header_extensions()); + UpdateRtpHeaderExtensionMap(rtp_header_extensions); if (!SetRtpTransportParameters(content, type, CS_LOCAL, rtp_header_extensions, error_desc)) { @@ -1363,9 +1408,16 @@ bool VideoChannel::SetLocalContent_w(const MediaContentDescription* content, error_desc); return false; } + for (const VideoCodec& codec : video->codecs()) { AddHandledPayloadType(codec.id); } + // Need to re-register the sink to update the handled payload. + if (!RegisterRtpDemuxerSink()) { + RTC_LOG(LS_ERROR) << "Failed to set up video demuxing."; + return false; + } + last_recv_params_ = recv_params; // TODO(pthatcher): Move local streams into VideoSendParameters, and @@ -1400,6 +1452,7 @@ bool VideoChannel::SetRemoteContent_w(const MediaContentDescription* content, RtpHeaderExtensions rtp_header_extensions = GetFilteredRtpHeaderExtensions(video->rtp_header_extensions()); + UpdateRtpHeaderExtensionMap(rtp_header_extensions); if (!SetRtpTransportParameters(content, type, CS_REMOTE, rtp_header_extensions, error_desc)) { return false; @@ -1459,7 +1512,6 @@ RtpDataChannel::~RtpDataChannel() { TRACE_EVENT0("webrtc", "RtpDataChannel::~RtpDataChannel"); // this can't be done in the base class, since it calls a virtual DisableMedia_w(); - Deinit(); } @@ -1528,6 +1580,7 @@ bool RtpDataChannel::SetLocalContent_w(const MediaContentDescription* content, RtpHeaderExtensions rtp_header_extensions = GetFilteredRtpHeaderExtensions(data->rtp_header_extensions()); + UpdateRtpHeaderExtensionMap(rtp_header_extensions); if (!SetRtpTransportParameters(content, type, CS_LOCAL, rtp_header_extensions, error_desc)) { @@ -1541,9 +1594,16 @@ bool RtpDataChannel::SetLocalContent_w(const MediaContentDescription* content, error_desc); return false; } + for (const DataCodec& codec : data->codecs()) { AddHandledPayloadType(codec.id); } + // Need to re-register the sink to update the handled payload. + if (!RegisterRtpDemuxerSink()) { + RTC_LOG(LS_ERROR) << "Failed to set up data demuxing."; + return false; + } + last_recv_params_ = recv_params; // TODO(pthatcher): Move local streams into DataSendParameters, and @@ -1587,6 +1647,7 @@ bool RtpDataChannel::SetRemoteContent_w(const MediaContentDescription* content, RtpHeaderExtensions rtp_header_extensions = GetFilteredRtpHeaderExtensions(data->rtp_header_extensions()); + UpdateRtpHeaderExtensionMap(rtp_header_extensions); RTC_LOG(LS_INFO) << "Setting remote data description"; if (!SetRtpTransportParameters(content, type, CS_REMOTE, rtp_header_extensions, error_desc)) { diff --git a/pc/channel.h b/pc/channel.h index 6a8367a754..52474c3a14 100644 --- a/pc/channel.h +++ b/pc/channel.h @@ -23,6 +23,7 @@ #include "api/rtpreceiverinterface.h" #include "api/videosinkinterface.h" #include "api/videosourceinterface.h" +#include "call/rtp_packet_sink_interface.h" #include "media/base/mediachannel.h" #include "media/base/mediaengine.h" #include "media/base/streamparams.h" @@ -69,9 +70,10 @@ class MediaContentDescription; // vtable, and the media channel's thread using BaseChannel as the // NetworkInterface. -class BaseChannel - : public rtc::MessageHandler, public sigslot::has_slots<>, - public MediaChannel::NetworkInterface { +class BaseChannel : public rtc::MessageHandler, + public sigslot::has_slots<>, + public MediaChannel::NetworkInterface, + public webrtc::RtpPacketSinkInterface { public: // If |srtp_required| is true, the channel will not send or receive any // RTP/RTCP packets without using SRTP (either using SDES or DTLS-SRTP). @@ -193,10 +195,8 @@ class BaseChannel virtual cricket::MediaType media_type() = 0; - // Public for testing. - // TODO(zstein): Remove this once channels register themselves with - // an RtpTransport in a more explicit way. - bool HandlesPayloadType(int payload_type) const; + // RtpPacketSinkInterface overrides. + void OnRtpPacket(const webrtc::RtpPacketReceived& packet) override; // Used by the RTCStatsCollector tests to set the transport name without // creating RtpTransports. @@ -264,12 +264,10 @@ class BaseChannel rtc::CopyOnWriteBuffer* packet, const rtc::PacketOptions& options); - bool WantsPacket(bool rtcp, const rtc::CopyOnWriteBuffer* packet); - void HandlePacket(bool rtcp, rtc::CopyOnWriteBuffer* packet, - const rtc::PacketTime& packet_time); - // TODO(zstein): packet can be const once the RtpTransport handles protection. + void OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& packet_time); void OnPacketReceived(bool rtcp, - rtc::CopyOnWriteBuffer* packet, + const rtc::CopyOnWriteBuffer& packet, const rtc::PacketTime& packet_time); void ProcessPacket(bool rtcp, const rtc::CopyOnWriteBuffer& packet, @@ -360,6 +358,11 @@ class BaseChannel void AddHandledPayloadType(int payload_type); + void UpdateRtpHeaderExtensionMap( + const RtpHeaderExtensions& header_extensions); + + bool RegisterRtpDemuxerSink(); + private: void ConnectToRtpTransport(); void DisconnectFromRtpTransport(); @@ -439,6 +442,9 @@ class BaseChannel // The cached encrypted header extension IDs. rtc::Optional> cached_send_extension_ids_; rtc::Optional> cached_recv_extension_ids_; + + bool encryption_disabled_ = false; + webrtc::RtpDemuxerCriteria demuxer_criteria_; }; // VoiceChannel is a specialization that adds support for early media, DTMF, diff --git a/pc/channel_unittest.cc b/pc/channel_unittest.cc index 7ee35013ac..753d6cd618 100644 --- a/pc/channel_unittest.cc +++ b/pc/channel_unittest.cc @@ -1609,10 +1609,6 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { EXPECT_TRUE(SendAccept()); EXPECT_EQ(rtcp_mux, !channel1_->NeedsRtcpTransport()); EXPECT_EQ(rtcp_mux, !channel2_->NeedsRtcpTransport()); - EXPECT_TRUE(channel1_->HandlesPayloadType(pl_type1)); - EXPECT_TRUE(channel2_->HandlesPayloadType(pl_type1)); - EXPECT_FALSE(channel1_->HandlesPayloadType(pl_type2)); - EXPECT_FALSE(channel2_->HandlesPayloadType(pl_type2)); // Both channels can receive pl_type1 only. SendCustomRtp1(kSsrc1, ++sequence_number1_1, pl_type1); @@ -1623,13 +1619,15 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { EXPECT_TRUE(CheckNoRtp1()); EXPECT_TRUE(CheckNoRtp2()); - // RTCP test + EXPECT_TRUE(SendInitiate()); + EXPECT_TRUE(SendAccept()); SendCustomRtp1(kSsrc1, ++sequence_number1_1, pl_type2); SendCustomRtp2(kSsrc2, ++sequence_number2_2, pl_type2); WaitForThreads(); EXPECT_FALSE(CheckCustomRtp2(kSsrc1, sequence_number1_1, pl_type2)); EXPECT_FALSE(CheckCustomRtp1(kSsrc2, sequence_number2_2, pl_type2)); + // RTCP test SendCustomRtcp1(kSsrc1); SendCustomRtcp2(kSsrc2); WaitForThreads(); diff --git a/pc/channelmanager.cc b/pc/channelmanager.cc index ead1da9cdd..2ae35f88eb 100644 --- a/pc/channelmanager.cc +++ b/pc/channelmanager.cc @@ -507,7 +507,6 @@ void ChannelManager::DestroyRtpDataChannel(RtpDataChannel* data_channel) { if (it == data_channels_.end()) { return; } - data_channels_.erase(it); } diff --git a/pc/channelmanager_unittest.cc b/pc/channelmanager_unittest.cc index d318ac5b20..bf799f6d08 100644 --- a/pc/channelmanager_unittest.cc +++ b/pc/channelmanager_unittest.cc @@ -195,11 +195,14 @@ class ChannelManagerTestWithRtpTransport RTPTransportType type = GetParam(); switch (type) { case RTPTransportType::kRtp: - return CreatePlainRtpTransport(); + return rtc::MakeUnique( + /*rtcp_mux_required=*/true); case RTPTransportType::kSrtp: - return CreateSrtpTransport(); + return rtc::MakeUnique( + /*rtcp_mux_required=*/true); case RTPTransportType::kDtlsSrtp: - return CreateDtlsSrtpTransport(); + return rtc::MakeUnique( + /*rtcp_mux_required=*/true); } return nullptr; } @@ -224,29 +227,6 @@ class ChannelManagerTestWithRtpTransport cm_->DestroyRtpDataChannel(rtp_data_channel); cm_->Terminate(); } - - private: - std::unique_ptr CreatePlainRtpTransport() { - return rtc::MakeUnique(/*rtcp_mux_required=*/true); - } - - std::unique_ptr CreateSrtpTransport() { - auto rtp_transport = - rtc::MakeUnique(/*rtcp_mux_required=*/true); - auto srtp_transport = - rtc::MakeUnique(std::move(rtp_transport)); - return srtp_transport; - } - - std::unique_ptr CreateDtlsSrtpTransport() { - auto rtp_transport = - rtc::MakeUnique(/*rtcp_mux_required=*/true); - auto srtp_transport = - rtc::MakeUnique(std::move(rtp_transport)); - auto dtls_srtp_transport_ = - rtc::MakeUnique(std::move(srtp_transport)); - return dtls_srtp_transport_; - } }; TEST_P(ChannelManagerTestWithRtpTransport, CreateDestroyChannels) { diff --git a/pc/dtlssrtptransport.cc b/pc/dtlssrtptransport.cc index 0b98a96293..b85930c56d 100644 --- a/pc/dtlssrtptransport.cc +++ b/pc/dtlssrtptransport.cc @@ -24,22 +24,8 @@ static const char kDtlsSrtpExporterLabel[] = "EXTRACTOR-dtls_srtp"; namespace webrtc { -DtlsSrtpTransport::DtlsSrtpTransport( - std::unique_ptr srtp_transport) - : RtpTransportInternalAdapter(srtp_transport.get()) { - srtp_transport_ = std::move(srtp_transport); - RTC_DCHECK(srtp_transport_); - srtp_transport_->SignalPacketReceived.connect( - this, &DtlsSrtpTransport::OnPacketReceived); - srtp_transport_->SignalReadyToSend.connect(this, - &DtlsSrtpTransport::OnReadyToSend); - srtp_transport_->SignalNetworkRouteChanged.connect( - this, &DtlsSrtpTransport::OnNetworkRouteChanged); - srtp_transport_->SignalWritableState.connect( - this, &DtlsSrtpTransport::OnWritableState); - srtp_transport_->SignalSentPacket.connect(this, - &DtlsSrtpTransport::OnSentPacket); -} +DtlsSrtpTransport::DtlsSrtpTransport(bool rtcp_mux_enabled) + : SrtpTransport(rtcp_mux_enabled) {} void DtlsSrtpTransport::SetDtlsTransports( cricket::DtlsTransportInternal* rtp_dtls_transport, @@ -54,7 +40,7 @@ void DtlsSrtpTransport::SetDtlsTransports( // DtlsTransport changes and wait until the DTLS handshake is complete to set // the newly negotiated parameters. if (IsActive()) { - srtp_transport_->ResetParams(); + ResetParams(); } const std::string transport_name = @@ -80,7 +66,7 @@ void DtlsSrtpTransport::SetDtlsTransports( } void DtlsSrtpTransport::SetRtcpMuxEnabled(bool enable) { - srtp_transport_->SetRtcpMuxEnabled(enable); + SrtpTransport::SetRtcpMuxEnabled(enable); if (enable) { UpdateWritableStateAndMaybeSetupDtlsSrtp(); } @@ -128,10 +114,9 @@ bool DtlsSrtpTransport::IsDtlsConnected() { } bool DtlsSrtpTransport::IsDtlsWritable() { - auto rtp_packet_transport = srtp_transport_->rtp_packet_transport(); auto rtcp_packet_transport = - rtcp_mux_enabled() ? nullptr : srtp_transport_->rtcp_packet_transport(); - return rtp_packet_transport && rtp_packet_transport->writable() && + rtcp_mux_enabled() ? nullptr : rtcp_dtls_transport_; + return rtp_dtls_transport_ && rtp_dtls_transport_->writable() && (!rtcp_packet_transport || rtcp_packet_transport->writable()); } @@ -170,11 +155,10 @@ void DtlsSrtpTransport::SetupRtpDtlsSrtp() { if (!ExtractParams(rtp_dtls_transport_, &selected_crypto_suite, &send_key, &recv_key) || - !srtp_transport_->SetRtpParams( - selected_crypto_suite, &send_key[0], - static_cast(send_key.size()), send_extension_ids, - selected_crypto_suite, &recv_key[0], - static_cast(recv_key.size()), recv_extension_ids)) { + !SetRtpParams(selected_crypto_suite, &send_key[0], + static_cast(send_key.size()), send_extension_ids, + selected_crypto_suite, &recv_key[0], + static_cast(recv_key.size()), recv_extension_ids)) { SignalDtlsSrtpSetupFailure(this, /*rtcp=*/false); RTC_LOG(LS_WARNING) << "DTLS-SRTP key installation for RTP failed"; } @@ -202,11 +186,11 @@ void DtlsSrtpTransport::SetupRtcpDtlsSrtp() { rtc::ZeroOnFreeBuffer rtcp_recv_key; if (!ExtractParams(rtcp_dtls_transport_, &selected_crypto_suite, &rtcp_send_key, &rtcp_recv_key) || - !srtp_transport_->SetRtcpParams( - selected_crypto_suite, &rtcp_send_key[0], - static_cast(rtcp_send_key.size()), send_extension_ids, - selected_crypto_suite, &rtcp_recv_key[0], - static_cast(rtcp_recv_key.size()), recv_extension_ids)) { + !SetRtcpParams(selected_crypto_suite, &rtcp_send_key[0], + static_cast(rtcp_send_key.size()), send_extension_ids, + selected_crypto_suite, &rtcp_recv_key[0], + static_cast(rtcp_recv_key.size()), + recv_extension_ids)) { SignalDtlsSrtpSetupFailure(this, /*rtcp=*/true); RTC_LOG(LS_WARNING) << "DTLS-SRTP key installation for RTCP failed"; } @@ -329,37 +313,20 @@ void DtlsSrtpTransport::OnDtlsState(cricket::DtlsTransportInternal* transport, transport == rtcp_dtls_transport_); if (state != cricket::DTLS_TRANSPORT_CONNECTED) { - srtp_transport_->ResetParams(); + ResetParams(); return; } MaybeSetupDtlsSrtp(); } -void DtlsSrtpTransport::OnWritableState(bool writable) { +void DtlsSrtpTransport::OnWritableState( + rtc::PacketTransportInternal* packet_transport) { + bool writable = IsTransportWritable(); SetWritable(writable); if (writable) { MaybeSetupDtlsSrtp(); } } -void DtlsSrtpTransport::OnSentPacket(const rtc::SentPacket& sent_packet) { - SignalSentPacket(sent_packet); -} - -void DtlsSrtpTransport::OnPacketReceived(bool rtcp, - rtc::CopyOnWriteBuffer* packet, - const rtc::PacketTime& packet_time) { - SignalPacketReceived(rtcp, packet, packet_time); -} - -void DtlsSrtpTransport::OnReadyToSend(bool ready) { - SignalReadyToSend(ready); -} - -void DtlsSrtpTransport::OnNetworkRouteChanged( - rtc::Optional network_route) { - SignalNetworkRouteChanged(network_route); -} - } // namespace webrtc diff --git a/pc/dtlssrtptransport.h b/pc/dtlssrtptransport.h index 02002b052a..fdd54b2615 100644 --- a/pc/dtlssrtptransport.h +++ b/pc/dtlssrtptransport.h @@ -16,20 +16,17 @@ #include #include "p2p/base/dtlstransportinternal.h" -#include "pc/rtptransportinternaladapter.h" #include "pc/srtptransport.h" #include "rtc_base/buffer.h" namespace webrtc { -// This class is intended to be used as an RtpTransport and it wraps both an -// SrtpTransport and DtlsTransports(RTP/RTCP). When the DTLS handshake is -// finished, it extracts the keying materials from DtlsTransport and sets them -// to SrtpTransport. -class DtlsSrtpTransport : public RtpTransportInternalAdapter { +// The subclass of SrtpTransport is used for DTLS-SRTP. When the DTLS handshake +// is finished, it extracts the keying materials from DtlsTransport and +// configures the SrtpSessions in the base class. +class DtlsSrtpTransport : public SrtpTransport { public: - explicit DtlsSrtpTransport( - std::unique_ptr srtp_transport); + explicit DtlsSrtpTransport(bool rtcp_mux_enabled); // Set P2P layer RTP/RTCP DtlsTransports. When using RTCP-muxing, // |rtcp_dtls_transport| is null. @@ -45,15 +42,6 @@ class DtlsSrtpTransport : public RtpTransportInternalAdapter { void UpdateRecvEncryptedHeaderExtensionIds( const std::vector& recv_extension_ids); - bool IsActive() { return srtp_transport_->IsActive(); } - - // Cache RTP Absoulute SendTime extension header ID. This is only used when - // external authentication is enabled. - void CacheRtpAbsSendTimeHeaderExtension(int rtp_abs_sendtime_extn_id) { - srtp_transport_->CacheRtpAbsSendTimeHeaderExtension( - rtp_abs_sendtime_extn_id); - } - // TODO(zhihuang): Remove this when we remove RtpTransportAdapter. RtpTransportAdapter* GetInternal() override { return nullptr; } @@ -83,16 +71,11 @@ class DtlsSrtpTransport : public RtpTransportInternalAdapter { void OnDtlsState(cricket::DtlsTransportInternal* dtls_transport, cricket::DtlsTransportState state); - void OnWritableState(bool writable); - void OnSentPacket(const rtc::SentPacket& sent_packet); - void OnPacketReceived(bool rtcp, - rtc::CopyOnWriteBuffer* packet, - const rtc::PacketTime& packet_time); - void OnReadyToSend(bool ready); - void OnNetworkRouteChanged(rtc::Optional network_route); + + // Override the RtpTransport::OnWritableState. + void OnWritableState(rtc::PacketTransportInternal* packet_transport) override; bool writable_ = false; - std::unique_ptr srtp_transport_; // Owned by the TransportController. cricket::DtlsTransportInternal* rtp_dtls_transport_ = nullptr; cricket::DtlsTransportInternal* rtcp_dtls_transport_ = nullptr; diff --git a/pc/dtlssrtptransport_unittest.cc b/pc/dtlssrtptransport_unittest.cc index 08a8151ee7..eb37b701db 100644 --- a/pc/dtlssrtptransport_unittest.cc +++ b/pc/dtlssrtptransport_unittest.cc @@ -33,50 +33,26 @@ using webrtc::RtpTransport; const int kRtpAuthTagLen = 10; -class TransportObserver : public sigslot::has_slots<> { - public: - void OnPacketReceived(bool rtcp, - rtc::CopyOnWriteBuffer* packet, - const rtc::PacketTime& packet_time) { - rtcp ? last_recv_rtcp_packet_ = *packet : last_recv_rtp_packet_ = *packet; - } - - void OnReadyToSend(bool ready) { ready_to_send_ = ready; } - - rtc::CopyOnWriteBuffer last_recv_rtp_packet() { - return last_recv_rtp_packet_; - } - - rtc::CopyOnWriteBuffer last_recv_rtcp_packet() { - return last_recv_rtcp_packet_; - } - - bool ready_to_send() { return ready_to_send_; } - - private: - rtc::CopyOnWriteBuffer last_recv_rtp_packet_; - rtc::CopyOnWriteBuffer last_recv_rtcp_packet_; - bool ready_to_send_ = false; -}; - class DtlsSrtpTransportTest : public testing::Test, public sigslot::has_slots<> { protected: DtlsSrtpTransportTest() {} + ~DtlsSrtpTransportTest() { + if (dtls_srtp_transport1_) { + dtls_srtp_transport1_->UnregisterRtpDemuxerSink(&transport_observer1_); + } + if (dtls_srtp_transport2_) { + dtls_srtp_transport2_->UnregisterRtpDemuxerSink(&transport_observer2_); + } + } + std::unique_ptr MakeDtlsSrtpTransport( FakeDtlsTransport* rtp_dtls, FakeDtlsTransport* rtcp_dtls, bool rtcp_mux_enabled) { - auto rtp_transport = rtc::MakeUnique(rtcp_mux_enabled); - - rtp_transport->AddHandledPayloadType(0x00); - rtp_transport->AddHandledPayloadType(0xc9); - - auto srtp_transport = - rtc::MakeUnique(std::move(rtp_transport)); auto dtls_srtp_transport = - rtc::MakeUnique(std::move(srtp_transport)); + rtc::MakeUnique(rtcp_mux_enabled); dtls_srtp_transport->SetDtlsTransports(rtp_dtls, rtcp_dtls); @@ -93,15 +69,24 @@ class DtlsSrtpTransportTest : public testing::Test, dtls_srtp_transport2_ = MakeDtlsSrtpTransport(rtp_dtls2, rtcp_dtls2, rtcp_mux_enabled); - dtls_srtp_transport1_->SignalPacketReceived.connect( - &transport_observer1_, &TransportObserver::OnPacketReceived); + dtls_srtp_transport1_->SignalRtcpPacketReceived.connect( + &transport_observer1_, + &webrtc::TransportObserver::OnRtcpPacketReceived); dtls_srtp_transport1_->SignalReadyToSend.connect( - &transport_observer1_, &TransportObserver::OnReadyToSend); + &transport_observer1_, &webrtc::TransportObserver::OnReadyToSend); - dtls_srtp_transport2_->SignalPacketReceived.connect( - &transport_observer2_, &TransportObserver::OnPacketReceived); + dtls_srtp_transport2_->SignalRtcpPacketReceived.connect( + &transport_observer2_, + &webrtc::TransportObserver::OnRtcpPacketReceived); dtls_srtp_transport2_->SignalReadyToSend.connect( - &transport_observer2_, &TransportObserver::OnReadyToSend); + &transport_observer2_, &webrtc::TransportObserver::OnReadyToSend); + webrtc::RtpDemuxerCriteria demuxer_criteria; + // 0x00 is the payload type used in kPcmuFrame. + demuxer_criteria.payload_types = {0x00}; + dtls_srtp_transport1_->RegisterRtpDemuxerSink(demuxer_criteria, + &transport_observer1_); + dtls_srtp_transport2_->RegisterRtpDemuxerSink(demuxer_criteria, + &transport_observer2_); } void CompleteDtlsHandshake(FakeDtlsTransport* fake_dtls1, @@ -251,8 +236,8 @@ class DtlsSrtpTransportTest : public testing::Test, std::unique_ptr dtls_srtp_transport1_; std::unique_ptr dtls_srtp_transport2_; - TransportObserver transport_observer1_; - TransportObserver transport_observer2_; + webrtc::TransportObserver transport_observer1_; + webrtc::TransportObserver transport_observer2_; int sequence_number_ = 0; }; diff --git a/pc/jseptransport2_unittest.cc b/pc/jseptransport2_unittest.cc index fc098ae9b6..e578e6b5ec 100644 --- a/pc/jseptransport2_unittest.cc +++ b/pc/jseptransport2_unittest.cc @@ -43,9 +43,8 @@ class JsepTransport2Test : public testing::Test, public sigslot::has_slots<> { const std::string& transport_name, rtc::PacketTransportInternal* rtp_packet_transport, rtc::PacketTransportInternal* rtcp_packet_transport) { - bool rtcp_mux_enabled = (rtcp_packet_transport == nullptr); - auto srtp_transport = - rtc::MakeUnique(rtcp_mux_enabled); + auto srtp_transport = rtc::MakeUnique( + rtcp_packet_transport == nullptr); srtp_transport->SetRtpPacketTransport(rtp_packet_transport); if (rtcp_packet_transport) { @@ -58,11 +57,8 @@ class JsepTransport2Test : public testing::Test, public sigslot::has_slots<> { const std::string& transport_name, cricket::DtlsTransportInternal* rtp_dtls_transport, cricket::DtlsTransportInternal* rtcp_dtls_transport) { - bool rtcp_mux_enabled = (rtcp_dtls_transport == nullptr); - auto srtp_transport = - rtc::MakeUnique(rtcp_mux_enabled); - auto dtls_srtp_transport = - rtc::MakeUnique(std::move(srtp_transport)); + auto dtls_srtp_transport = rtc::MakeUnique( + rtcp_dtls_transport == nullptr); dtls_srtp_transport->SetDtlsTransports(rtp_dtls_transport, rtcp_dtls_transport); diff --git a/pc/jseptransportcontroller.cc b/pc/jseptransportcontroller.cc index 5235791792..10725b5127 100644 --- a/pc/jseptransportcontroller.cc +++ b/pc/jseptransportcontroller.cc @@ -457,15 +457,12 @@ JsepTransportController::CreateDtlsSrtpTransport( cricket::DtlsTransportInternal* rtp_dtls_transport, cricket::DtlsTransportInternal* rtcp_dtls_transport) { RTC_DCHECK(network_thread_->IsCurrent()); - bool rtcp_mux_enabled = rtcp_dtls_transport == nullptr; - auto srtp_transport = - rtc::MakeUnique(rtcp_mux_enabled); - if (config_.enable_external_auth) { - srtp_transport->EnableExternalAuth(); - } - auto dtls_srtp_transport = - rtc::MakeUnique(std::move(srtp_transport)); + auto dtls_srtp_transport = rtc::MakeUnique( + rtcp_dtls_transport == nullptr); + if (config_.enable_external_auth) { + dtls_srtp_transport->EnableExternalAuth(); + } dtls_srtp_transport->SetDtlsTransports(rtp_dtls_transport, rtcp_dtls_transport); diff --git a/pc/peerconnection.cc b/pc/peerconnection.cc index 2d28b8775e..2816cb2a9c 100644 --- a/pc/peerconnection.cc +++ b/pc/peerconnection.cc @@ -6190,12 +6190,9 @@ void PeerConnection::DestroyDataChannel() { void PeerConnection::DestroyBaseChannel(cricket::BaseChannel* channel) { RTC_DCHECK(channel); - RTC_DCHECK(channel->rtp_dtls_transport()); - // Need to cache these before destroying the base channel so that we do not // access uninitialized memory. - const std::string transport_name = - channel->rtp_dtls_transport()->transport_name(); + const std::string transport_name = channel->transport_name(); const bool need_to_delete_rtcp = (channel->rtcp_dtls_transport() != nullptr); switch (channel->media_type()) { diff --git a/pc/peerconnection_media_unittest.cc b/pc/peerconnection_media_unittest.cc index 8d1dd7694f..6dabb6e939 100644 --- a/pc/peerconnection_media_unittest.cc +++ b/pc/peerconnection_media_unittest.cc @@ -959,8 +959,8 @@ void RenameContent(cricket::SessionDescription* desc, // Tests that an answer responds with the same MIDs as the offer. TEST_P(PeerConnectionMediaTest, AnswerHasSameMidsAsOffer) { - const std::string kAudioMid = "not default1"; - const std::string kVideoMid = "not default2"; + const std::string kAudioMid = "notdefault1"; + const std::string kVideoMid = "notdefault2"; auto caller = CreatePeerConnectionWithAudioVideo(); auto callee = CreatePeerConnectionWithAudioVideo(); @@ -980,8 +980,8 @@ TEST_P(PeerConnectionMediaTest, AnswerHasSameMidsAsOffer) { // Test that if the callee creates a re-offer, the MIDs are the same as the // original offer. TEST_P(PeerConnectionMediaTest, ReOfferHasSameMidsAsFirstOffer) { - const std::string kAudioMid = "not default1"; - const std::string kVideoMid = "not default2"; + const std::string kAudioMid = "notdefault1"; + const std::string kVideoMid = "notdefault2"; auto caller = CreatePeerConnectionWithAudioVideo(); auto callee = CreatePeerConnectionWithAudioVideo(); diff --git a/pc/rtptransport.cc b/pc/rtptransport.cc index 26f7e3e4c9..f59be3bc64 100644 --- a/pc/rtptransport.cc +++ b/pc/rtptransport.cc @@ -10,7 +10,10 @@ #include "pc/rtptransport.h" +#include + #include "media/base/rtputils.h" +#include "modules/rtp_rtcp/source/rtp_packet_received.h" #include "p2p/base/p2pconstants.h" #include "p2p/base/packettransportinterface.h" #include "rtc_base/checks.h" @@ -44,7 +47,7 @@ void RtpTransport::SetRtpPacketTransport( new_packet_transport->SignalReadPacket.connect(this, &RtpTransport::OnReadPacket); new_packet_transport->SignalNetworkRouteChanged.connect( - this, &RtpTransport::OnNetworkRouteChange); + this, &RtpTransport::OnNetworkRouteChanged); new_packet_transport->SignalWritableState.connect( this, &RtpTransport::OnWritableState); new_packet_transport->SignalSentPacket.connect(this, @@ -80,7 +83,7 @@ void RtpTransport::SetRtcpPacketTransport( new_packet_transport->SignalReadPacket.connect(this, &RtpTransport::OnReadPacket); new_packet_transport->SignalNetworkRouteChanged.connect( - this, &RtpTransport::OnNetworkRouteChange); + this, &RtpTransport::OnNetworkRouteChanged); new_packet_transport->SignalWritableState.connect( this, &RtpTransport::OnWritableState); new_packet_transport->SignalSentPacket.connect(this, @@ -134,16 +137,27 @@ bool RtpTransport::SendPacket(bool rtcp, return true; } -bool RtpTransport::HandlesPacket(const uint8_t* data, size_t len) { - return bundle_filter_.DemuxPacket(data, len); +void RtpTransport::UpdateRtpHeaderExtensionMap( + const cricket::RtpHeaderExtensions& header_extensions) { + header_extension_map_ = RtpHeaderExtensionMap(header_extensions); } -bool RtpTransport::HandlesPayloadType(int payload_type) const { - return bundle_filter_.FindPayloadType(payload_type); +bool RtpTransport::RegisterRtpDemuxerSink(const RtpDemuxerCriteria& criteria, + RtpPacketSinkInterface* sink) { + rtp_demuxer_.RemoveSink(sink); + if (!rtp_demuxer_.AddSink(criteria, sink)) { + RTC_LOG(LS_ERROR) << "Failed to register the sink for RTP demuxer."; + return false; + } + return true; } -void RtpTransport::AddHandledPayloadType(int payload_type) { - bundle_filter_.AddPayloadType(payload_type); +bool RtpTransport::UnregisterRtpDemuxerSink(RtpPacketSinkInterface* sink) { + if (!rtp_demuxer_.RemoveSink(sink)) { + RTC_LOG(LS_ERROR) << "Failed to unregister the sink for RTP demuxer."; + return false; + } + return true; } PacketTransportInterface* RtpTransport::GetRtpPacketTransport() const { @@ -180,11 +194,26 @@ RtpTransportParameters RtpTransport::GetParameters() const { return parameters_; } +void RtpTransport::DemuxPacket(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& time) { + webrtc::RtpPacketReceived parsed_packet(&header_extension_map_); + if (!parsed_packet.Parse(std::move(*packet))) { + RTC_LOG(LS_ERROR) + << "Failed to parse the incoming RTP packet before demuxing. Drop it."; + return; + } + + if (time.timestamp != -1) { + parsed_packet.set_arrival_time_ms((time.timestamp + 500) / 1000); + } + rtp_demuxer_.OnRtpPacket(parsed_packet); +} + RtpTransportAdapter* RtpTransport::GetInternal() { return nullptr; } -bool RtpTransport::IsRtpTransportWritable() { +bool RtpTransport::IsTransportWritable() { auto rtcp_packet_transport = rtcp_mux_enabled_ ? nullptr : rtcp_packet_transport_; return rtp_packet_transport_ && rtp_packet_transport_->writable() && @@ -195,7 +224,7 @@ void RtpTransport::OnReadyToSend(rtc::PacketTransportInternal* transport) { SetReadyToSend(transport == rtcp_packet_transport_, true); } -void RtpTransport::OnNetworkRouteChange( +void RtpTransport::OnNetworkRouteChanged( rtc::Optional network_route) { SignalNetworkRouteChanged(network_route); } @@ -204,7 +233,7 @@ void RtpTransport::OnWritableState( rtc::PacketTransportInternal* packet_transport) { RTC_DCHECK(packet_transport == rtp_packet_transport_ || packet_transport == rtcp_packet_transport_); - SignalWritableState(IsRtpTransportWritable()); + SignalWritableState(IsTransportWritable()); } void RtpTransport::OnSentPacket(rtc::PacketTransportInternal* packet_transport, @@ -214,6 +243,44 @@ void RtpTransport::OnSentPacket(rtc::PacketTransportInternal* packet_transport, SignalSentPacket(sent_packet); } +void RtpTransport::OnRtpPacketReceived(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& packet_time) { + DemuxPacket(packet, packet_time); +} + +void RtpTransport::OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& packet_time) { + SignalRtcpPacketReceived(packet, packet_time); +} + +void RtpTransport::OnReadPacket(rtc::PacketTransportInternal* transport, + const char* data, + size_t len, + const rtc::PacketTime& packet_time, + int flags) { + TRACE_EVENT0("webrtc", "RtpTransport::OnReadPacket"); + + // When using RTCP multiplexing we might get RTCP packets on the RTP + // transport. We check the RTP payload type to determine if it is RTCP. + bool rtcp = transport == rtcp_packet_transport() || + cricket::IsRtcp(data, static_cast(len)); + rtc::CopyOnWriteBuffer packet(data, len); + + // Protect ourselves against crazy data. + if (!cricket::IsValidRtpRtcpPacketSize(rtcp, packet.size())) { + RTC_LOG(LS_ERROR) << "Dropping incoming " + << cricket::RtpRtcpStringLiteral(rtcp) + << " packet: wrong size=" << packet.size(); + return; + } + + if (rtcp) { + OnRtcpPacketReceived(&packet, packet_time); + } else { + OnRtpPacketReceived(&packet, packet_time); + } +} + void RtpTransport::SetReadyToSend(bool rtcp, bool ready) { if (rtcp) { rtcp_ready_to_send_ = ready; @@ -233,51 +300,4 @@ void RtpTransport::MaybeSignalReadyToSend() { } } -// Check the RTP payload type. If 63 < payload type < 96, it's RTCP. -// For additional details, see http://tools.ietf.org/html/rfc5761. -bool IsRtcp(const char* data, int len) { - if (len < 2) { - return false; - } - char pt = data[1] & 0x7F; - return (63 < pt) && (pt < 96); -} - -void RtpTransport::OnReadPacket(rtc::PacketTransportInternal* transport, - const char* data, - size_t len, - const rtc::PacketTime& packet_time, - int flags) { - TRACE_EVENT0("webrtc", "RtpTransport::OnReadPacket"); - - // When using RTCP multiplexing we might get RTCP packets on the RTP - // transport. We check the RTP payload type to determine if it is RTCP. - bool rtcp = transport == rtcp_packet_transport() || - IsRtcp(data, static_cast(len)); - rtc::CopyOnWriteBuffer packet(data, len); - - if (!WantsPacket(rtcp, &packet)) { - return; - } - // This mutates |packet| if it is protected. - SignalPacketReceived(rtcp, &packet, packet_time); -} - -bool RtpTransport::WantsPacket(bool rtcp, - const rtc::CopyOnWriteBuffer* packet) { - // Protect ourselves against crazy data. - if (!packet || !cricket::IsValidRtpRtcpPacketSize(rtcp, packet->size())) { - RTC_LOG(LS_ERROR) << "Dropping incoming " - << cricket::RtpRtcpStringLiteral(rtcp) - << " packet: wrong size=" << packet->size(); - return false; - } - if (rtcp) { - // Permit all (seemingly valid) RTCP packets. - return true; - } - // Check whether we handle this payload. - return HandlesPacket(packet->data(), packet->size()); -} - } // namespace webrtc diff --git a/pc/rtptransport.h b/pc/rtptransport.h index 637d447454..a857fea3dd 100644 --- a/pc/rtptransport.h +++ b/pc/rtptransport.h @@ -13,7 +13,8 @@ #include -#include "pc/bundlefilter.h" +#include "call/rtp_demuxer.h" +#include "modules/rtp_rtcp/include/rtp_header_extension_map.h" #include "pc/rtptransportinternal.h" #include "rtc_base/sigslot.h" @@ -66,9 +67,13 @@ class RtpTransport : public RtpTransportInternal { const rtc::PacketOptions& options, int flags) override; - bool HandlesPayloadType(int payload_type) const override; + void UpdateRtpHeaderExtensionMap( + const cricket::RtpHeaderExtensions& header_extensions) override; - void AddHandledPayloadType(int payload_type) override; + bool RegisterRtpDemuxerSink(const RtpDemuxerCriteria& criteria, + RtpPacketSinkInterface* sink) override; + + bool UnregisterRtpDemuxerSink(RtpPacketSinkInterface* sink) override; void SetMetricsObserver( rtc::scoped_refptr metrics_observer) override {} @@ -77,15 +82,35 @@ class RtpTransport : public RtpTransportInternal { // TODO(zstein): Remove this when we remove RtpTransportAdapter. RtpTransportAdapter* GetInternal() override; - private: - bool IsRtpTransportWritable(); - bool HandlesPacket(const uint8_t* data, size_t len); + // These methods will be used in the subclasses. + void DemuxPacket(rtc::CopyOnWriteBuffer* packet, const rtc::PacketTime& time); + bool SendPacket(bool rtcp, + rtc::CopyOnWriteBuffer* packet, + const rtc::PacketOptions& options, + int flags); + + bool IsTransportWritable(); + + // Overridden by SrtpTransport. + virtual void OnNetworkRouteChanged( + rtc::Optional network_route); + virtual void OnRtpPacketReceived(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& packet_time); + virtual void OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& packet_time); + // Overridden by DtlsSrtpTransport. + virtual void OnWritableState(rtc::PacketTransportInternal* packet_transport); + + private: void OnReadyToSend(rtc::PacketTransportInternal* transport); - void OnNetworkRouteChange(rtc::Optional network_route); - void OnWritableState(rtc::PacketTransportInternal* packet_transport); void OnSentPacket(rtc::PacketTransportInternal* packet_transport, const rtc::SentPacket& sent_packet); + void OnReadPacket(rtc::PacketTransportInternal* transport, + const char* data, + size_t len, + const rtc::PacketTime& packet_time, + int flags); // Updates "ready to send" for an individual channel and fires // SignalReadyToSend. @@ -93,18 +118,17 @@ class RtpTransport : public RtpTransportInternal { void MaybeSignalReadyToSend(); - bool SendPacket(bool rtcp, - rtc::CopyOnWriteBuffer* packet, - const rtc::PacketOptions& options, - int flags); - - void OnReadPacket(rtc::PacketTransportInternal* transport, - const char* data, - size_t len, - const rtc::PacketTime& packet_time, - int flags); - - bool WantsPacket(bool rtcp, const rtc::CopyOnWriteBuffer* packet); + // SRTP specific methods. + // TODO(zhihuang): Improve the inheritance model so that the RtpTransport + // doesn't need to implement SRTP specfic methods. + RTCError SetSrtpSendKey(const cricket::CryptoParams& params) override { + RTC_NOTREACHED(); + return RTCError::OK(); + } + RTCError SetSrtpReceiveKey(const cricket::CryptoParams& params) override { + RTC_NOTREACHED(); + return RTCError::OK(); + } bool rtcp_mux_enabled_; @@ -116,8 +140,10 @@ class RtpTransport : public RtpTransportInternal { bool rtcp_ready_to_send_ = false; RtpTransportParameters parameters_; + RtpDemuxer rtp_demuxer_; - cricket::BundleFilter bundle_filter_; + // Used for identifying the MID for RtpDemuxer. + RtpHeaderExtensionMap header_extension_map_; }; } // namespace webrtc diff --git a/pc/rtptransport_unittest.cc b/pc/rtptransport_unittest.cc index 3876aa3998..efc2e8ccf7 100644 --- a/pc/rtptransport_unittest.cc +++ b/pc/rtptransport_unittest.cc @@ -197,49 +197,37 @@ TEST(RtpTransportTest, SetRtcpTransportWithNetworkRouteChanged) { EXPECT_FALSE(observer.network_route()); } -class SignalCounter : public sigslot::has_slots<> { - public: - explicit SignalCounter(RtpTransport* transport) { - transport->SignalReadyToSend.connect(this, &SignalCounter::OnReadyToSend); - } - int count() const { return count_; } - void OnReadyToSend(bool ready) { ++count_; } - - private: - int count_ = 0; -}; - TEST(RtpTransportTest, ChangingReadyToSendStateOnlySignalsWhenChanged) { RtpTransport transport(kMuxEnabled); - SignalCounter observer(&transport); + TransportObserver observer(&transport); rtc::FakePacketTransport fake_rtp("fake_rtp"); fake_rtp.SetWritable(true); // State changes, so we should signal. transport.SetRtpPacketTransport(&fake_rtp); - EXPECT_EQ(observer.count(), 1); + EXPECT_EQ(observer.ready_to_send_signal_count(), 1); // State does not change, so we should not signal. transport.SetRtpPacketTransport(&fake_rtp); - EXPECT_EQ(observer.count(), 1); + EXPECT_EQ(observer.ready_to_send_signal_count(), 1); // State does not change, so we should not signal. transport.SetRtcpMuxEnabled(true); - EXPECT_EQ(observer.count(), 1); + EXPECT_EQ(observer.ready_to_send_signal_count(), 1); // State changes, so we should signal. transport.SetRtcpMuxEnabled(false); - EXPECT_EQ(observer.count(), 2); + EXPECT_EQ(observer.ready_to_send_signal_count(), 2); } // Test that SignalPacketReceived fires with rtcp=true when a RTCP packet is // received. TEST(RtpTransportTest, SignalDemuxedRtcp) { RtpTransport transport(kMuxDisabled); - SignalPacketReceivedCounter observer(&transport); rtc::FakePacketTransport fake_rtp("fake_rtp"); fake_rtp.SetDestination(&fake_rtp, true); transport.SetRtpPacketTransport(&fake_rtp); + TransportObserver observer(&transport); // An rtcp packet. const char data[] = {0, 73, 0, 0}; @@ -259,11 +247,15 @@ static const int kRtpLen = 12; // handled payload type is received. TEST(RtpTransportTest, SignalHandledRtpPayloadType) { RtpTransport transport(kMuxDisabled); - SignalPacketReceivedCounter observer(&transport); rtc::FakePacketTransport fake_rtp("fake_rtp"); fake_rtp.SetDestination(&fake_rtp, true); + // Disable the encryption to allow raw RTP data. transport.SetRtpPacketTransport(&fake_rtp); - transport.AddHandledPayloadType(0x11); + TransportObserver observer(&transport); + RtpDemuxerCriteria demuxer_criteria; + // Add a handled payload type. + demuxer_criteria.payload_types = {0x11}; + transport.RegisterRtpDemuxerSink(demuxer_criteria, &observer); // An rtp packet. const rtc::PacketOptions options; @@ -272,16 +264,22 @@ TEST(RtpTransportTest, SignalHandledRtpPayloadType) { fake_rtp.SendPacket(rtp_data.data(), kRtpLen, options, flags); EXPECT_EQ(1, observer.rtp_count()); EXPECT_EQ(0, observer.rtcp_count()); + // Remove the sink before destroying the transport. + transport.UnregisterRtpDemuxerSink(&observer); } // Test that SignalPacketReceived does not fire when a RTP packet with an // unhandled payload type is received. TEST(RtpTransportTest, DontSignalUnhandledRtpPayloadType) { RtpTransport transport(kMuxDisabled); - SignalPacketReceivedCounter observer(&transport); rtc::FakePacketTransport fake_rtp("fake_rtp"); fake_rtp.SetDestination(&fake_rtp, true); transport.SetRtpPacketTransport(&fake_rtp); + TransportObserver observer(&transport); + RtpDemuxerCriteria demuxer_criteria; + // Add an unhandled payload type. + demuxer_criteria.payload_types = {0x12}; + transport.RegisterRtpDemuxerSink(demuxer_criteria, &observer); const rtc::PacketOptions options; const int flags = 0; @@ -289,6 +287,8 @@ TEST(RtpTransportTest, DontSignalUnhandledRtpPayloadType) { fake_rtp.SendPacket(rtp_data.data(), kRtpLen, options, flags); EXPECT_EQ(0, observer.rtp_count()); EXPECT_EQ(0, observer.rtcp_count()); + // Remove the sink before destroying the transport. + transport.UnregisterRtpDemuxerSink(&observer); } } // namespace webrtc diff --git a/pc/rtptransportinternal.h b/pc/rtptransportinternal.h index 0665fd76bc..4b35a57666 100644 --- a/pc/rtptransportinternal.h +++ b/pc/rtptransportinternal.h @@ -13,9 +13,11 @@ #include -#include "api/ortc/rtptransportinterface.h" +#include "api/ortc/srtptransportinterface.h" #include "api/umametrics.h" +#include "call/rtp_demuxer.h" #include "p2p/base/icetransportinternal.h" +#include "pc/sessiondescription.h" #include "rtc_base/networkroute.h" #include "rtc_base/sigslot.h" @@ -27,11 +29,11 @@ struct PacketTime; namespace webrtc { -// This represents the internal interface beneath RtpTransportInterface; +// This represents the internal interface beneath SrtpTransportInterface; // it is not accessible to API consumers but is accessible to internal classes // in order to send and receive RTP and RTCP packets belonging to a single RTP // session. Additional convenience and configuration methods are also provided. -class RtpTransportInternal : public RtpTransportInterface, +class RtpTransportInternal : public SrtpTransportInterface, public sigslot::has_slots<> { public: virtual void SetRtcpMuxEnabled(bool enable) = 0; @@ -52,11 +54,11 @@ class RtpTransportInternal : public RtpTransportInterface, // than just "writable"; it means the last send didn't return ENOTCONN. sigslot::signal1 SignalReadyToSend; - // TODO(zstein): Consider having two signals - RtpPacketReceived and - // RtcpPacketReceived. - // The first argument is true for RTCP packets and false for RTP packets. - sigslot::signal3 - SignalPacketReceived; + // Called whenever an RTCP packet is received. There is no equivalent signal + // for RTP packets because they would be forwarded to the BaseChannel through + // the RtpDemuxer callback. + sigslot::signal2 + SignalRtcpPacketReceived; // Called whenever the network route of the P2P layer transport changes. // The argument is an optional network route. @@ -80,12 +82,22 @@ class RtpTransportInternal : public RtpTransportInterface, const rtc::PacketOptions& options, int flags) = 0; - virtual bool HandlesPayloadType(int payload_type) const = 0; - - virtual void AddHandledPayloadType(int payload_type) = 0; + // This method updates the RTP header extension map so that the RTP transport + // can parse the received packets and identify the MID. This is called by the + // BaseChannel when setting the content description. + // + // Note: This doesn't take the BUNDLE case in account meaning the RTP header + // extension maps are not merged when BUNDLE is enabled. This is fine because + // the ID for MID should be consistent among all the RTP transports. + virtual void UpdateRtpHeaderExtensionMap( + const cricket::RtpHeaderExtensions& header_extensions) = 0; virtual void SetMetricsObserver( rtc::scoped_refptr metrics_observer) = 0; + virtual bool RegisterRtpDemuxerSink(const RtpDemuxerCriteria& criteria, + RtpPacketSinkInterface* sink) = 0; + + virtual bool UnregisterRtpDemuxerSink(RtpPacketSinkInterface* sink) = 0; }; } // namespace webrtc diff --git a/pc/rtptransportinternaladapter.h b/pc/rtptransportinternaladapter.h index 6a2d7e2aa7..f46b3a57b4 100644 --- a/pc/rtptransportinternaladapter.h +++ b/pc/rtptransportinternaladapter.h @@ -67,12 +67,22 @@ class RtpTransportInternalAdapter : public RtpTransportInternal { return transport_->SendRtcpPacket(packet, options, flags); } - bool HandlesPayloadType(int payload_type) const override { - return transport_->HandlesPayloadType(payload_type); + void UpdateRtpHeaderExtensionMap( + const cricket::RtpHeaderExtensions& header_extensions) override { + transport_->UpdateRtpHeaderExtensionMap(header_extensions); } - void AddHandledPayloadType(int payload_type) override { - return transport_->AddHandledPayloadType(payload_type); + void RegisterRtpDemuxerSink(const RtpDemuxerCriteria& criteria, + RtpPacketSinkInterface* sink) override { + transport_->RegisterRtpDemuxerSink(criteria, sink); + } + + void UnregisterRtpDemuxerSink(RtpPacketSinkInterface* sink) override { + transport_->UnregisterRtpDemuxerSink(sink); + } + + void SetEncryptionDisabled(bool encryption_disabled) override { + transport_->SetEncryptionDisabled(encryption_disabled); } // RtpTransportInterface overrides. diff --git a/pc/rtptransporttestutil.h b/pc/rtptransporttestutil.h index c2bdaad23e..9489fa2bfe 100644 --- a/pc/rtptransporttestutil.h +++ b/pc/rtptransporttestutil.h @@ -11,32 +11,67 @@ #ifndef PC_RTPTRANSPORTTESTUTIL_H_ #define PC_RTPTRANSPORTTESTUTIL_H_ +#include "call/rtp_packet_sink_interface.h" +#include "modules/rtp_rtcp/source/rtp_packet_received.h" #include "pc/rtptransportinternal.h" #include "rtc_base/sigslot.h" namespace webrtc { -class SignalPacketReceivedCounter : public sigslot::has_slots<> { +// Used to handle the signals when the RtpTransport receives an RTP/RTCP packet. +// Used in Rtp/Srtp/DtlsTransport unit tests. +class TransportObserver : public RtpPacketSinkInterface, + public sigslot::has_slots<> { public: - explicit SignalPacketReceivedCounter(RtpTransportInternal* transport) { - transport->SignalPacketReceived.connect( - this, &SignalPacketReceivedCounter::OnPacketReceived); + TransportObserver() {} + + explicit TransportObserver(RtpTransportInternal* rtp_transport) { + rtp_transport->SignalRtcpPacketReceived.connect( + this, &TransportObserver::OnRtcpPacketReceived); + rtp_transport->SignalReadyToSend.connect(this, + &TransportObserver::OnReadyToSend); } - int rtcp_count() const { return rtcp_count_; } + + // RtpPacketInterface override. + void OnRtpPacket(const RtpPacketReceived& packet) override { + rtp_count_++; + last_recv_rtp_packet_ = packet.Buffer(); + } + + void OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& packet_time) { + RTC_LOG(LS_INFO) << "Received an RTCP packet."; + rtcp_count_++; + last_recv_rtcp_packet_ = *packet; + } + int rtp_count() const { return rtp_count_; } + int rtcp_count() const { return rtcp_count_; } + + rtc::CopyOnWriteBuffer last_recv_rtp_packet() { + return last_recv_rtp_packet_; + } + + rtc::CopyOnWriteBuffer last_recv_rtcp_packet() { + return last_recv_rtcp_packet_; + } + + void OnReadyToSend(bool ready) { + ready_to_send_signal_count_++; + ready_to_send_ = ready; + } + + bool ready_to_send() { return ready_to_send_; } + + int ready_to_send_signal_count() { return ready_to_send_signal_count_; } private: - void OnPacketReceived(bool rtcp, - rtc::CopyOnWriteBuffer*, - const rtc::PacketTime&) { - if (rtcp) { - ++rtcp_count_; - } else { - ++rtp_count_; - } - } - int rtcp_count_ = 0; + bool ready_to_send_ = false; int rtp_count_ = 0; + int rtcp_count_ = 0; + int ready_to_send_signal_count_ = 0; + rtc::CopyOnWriteBuffer last_recv_rtp_packet_; + rtc::CopyOnWriteBuffer last_recv_rtcp_packet_; }; } // namespace webrtc diff --git a/pc/srtptransport.cc b/pc/srtptransport.cc index 5eff3c932b..2409aee445 100644 --- a/pc/srtptransport.cc +++ b/pc/srtptransport.cc @@ -25,155 +25,133 @@ namespace webrtc { SrtpTransport::SrtpTransport(bool rtcp_mux_enabled) - : RtpTransportInternalAdapter(new RtpTransport(rtcp_mux_enabled)) { - // Own the raw pointer |transport| from the base class. - rtp_transport_.reset(transport_); - RTC_DCHECK(rtp_transport_); - ConnectToRtpTransport(); -} - -SrtpTransport::SrtpTransport( - std::unique_ptr rtp_transport) - : RtpTransportInternalAdapter(rtp_transport.get()), - rtp_transport_(std::move(rtp_transport)) { - RTC_DCHECK(rtp_transport_); - ConnectToRtpTransport(); -} - -void SrtpTransport::ConnectToRtpTransport() { - rtp_transport_->SignalPacketReceived.connect( - this, &SrtpTransport::OnPacketReceived); - rtp_transport_->SignalReadyToSend.connect(this, - &SrtpTransport::OnReadyToSend); - rtp_transport_->SignalNetworkRouteChanged.connect( - this, &SrtpTransport::OnNetworkRouteChanged); - rtp_transport_->SignalWritableState.connect(this, - &SrtpTransport::OnWritableState); - rtp_transport_->SignalSentPacket.connect(this, &SrtpTransport::OnSentPacket); -} + : RtpTransport(rtcp_mux_enabled) {} bool SrtpTransport::SendRtpPacket(rtc::CopyOnWriteBuffer* packet, const rtc::PacketOptions& options, int flags) { - return SendPacket(false, packet, options, flags); -} - -bool SrtpTransport::SendRtcpPacket(rtc::CopyOnWriteBuffer* packet, - const rtc::PacketOptions& options, - int flags) { - return SendPacket(true, packet, options, flags); -} - -bool SrtpTransport::SendPacket(bool rtcp, - rtc::CopyOnWriteBuffer* packet, - const rtc::PacketOptions& options, - int flags) { if (!IsActive()) { RTC_LOG(LS_ERROR) << "Failed to send the packet because SRTP transport is inactive."; return false; } - rtc::PacketOptions updated_options = options; - rtc::CopyOnWriteBuffer cp = *packet; TRACE_EVENT0("webrtc", "SRTP Encode"); bool res; uint8_t* data = packet->data(); int len = static_cast(packet->size()); - if (!rtcp) { // If ENABLE_EXTERNAL_AUTH flag is on then packet authentication is not done // inside libsrtp for a RTP packet. A external HMAC module will be writing // a fake HMAC value. This is ONLY done for a RTP packet. // Socket layer will update rtp sendtime extension header if present in // packet with current time before updating the HMAC. #if !defined(ENABLE_EXTERNAL_AUTH) - res = ProtectRtp(data, len, static_cast(packet->capacity()), &len); + res = ProtectRtp(data, len, static_cast(packet->capacity()), &len); #else - if (!IsExternalAuthActive()) { - res = ProtectRtp(data, len, static_cast(packet->capacity()), &len); - } else { - updated_options.packet_time_params.rtp_sendtime_extension_id = - rtp_abs_sendtime_extn_id_; - res = ProtectRtp(data, len, static_cast(packet->capacity()), &len, - &updated_options.packet_time_params.srtp_packet_index); - // If protection succeeds, let's get auth params from srtp. + if (!IsExternalAuthActive()) { + res = ProtectRtp(data, len, static_cast(packet->capacity()), &len); + } else { + updated_options.packet_time_params.rtp_sendtime_extension_id = + rtp_abs_sendtime_extn_id_; + res = ProtectRtp(data, len, static_cast(packet->capacity()), &len, + &updated_options.packet_time_params.srtp_packet_index); + // If protection succeeds, let's get auth params from srtp. + if (res) { + uint8_t* auth_key = NULL; + int key_len; + res = GetRtpAuthParams( + &auth_key, &key_len, + &updated_options.packet_time_params.srtp_auth_tag_len); if (res) { - uint8_t* auth_key = NULL; - int key_len; - res = GetRtpAuthParams( - &auth_key, &key_len, - &updated_options.packet_time_params.srtp_auth_tag_len); - if (res) { - updated_options.packet_time_params.srtp_auth_key.resize(key_len); - updated_options.packet_time_params.srtp_auth_key.assign( - auth_key, auth_key + key_len); - } + updated_options.packet_time_params.srtp_auth_key.resize(key_len); + updated_options.packet_time_params.srtp_auth_key.assign( + auth_key, auth_key + key_len); } } + } #endif - if (!res) { - int seq_num = -1; - uint32_t ssrc = 0; - cricket::GetRtpSeqNum(data, len, &seq_num); - cricket::GetRtpSsrc(data, len, &ssrc); - RTC_LOG(LS_ERROR) << "Failed to protect RTP packet: size=" << len - << ", seqnum=" << seq_num << ", SSRC=" << ssrc; - return false; - } - } else { - res = ProtectRtcp(data, len, static_cast(packet->capacity()), &len); - if (!res) { - int type = -1; - cricket::GetRtcpType(data, len, &type); - RTC_LOG(LS_ERROR) << "Failed to protect RTCP packet: size=" << len - << ", type=" << type; - return false; - } + if (!res) { + int seq_num = -1; + uint32_t ssrc = 0; + cricket::GetRtpSeqNum(data, len, &seq_num); + cricket::GetRtpSsrc(data, len, &ssrc); + RTC_LOG(LS_ERROR) << "Failed to protect RTP packet: size=" << len + << ", seqnum=" << seq_num << ", SSRC=" << ssrc; + return false; } // Update the length of the packet now that we've added the auth tag. packet->SetSize(len); - return rtcp ? rtp_transport_->SendRtcpPacket(packet, updated_options, flags) - : rtp_transport_->SendRtpPacket(packet, updated_options, flags); + return SendPacket(/*rtcp=*/false, packet, updated_options, flags); } -void SrtpTransport::OnPacketReceived(bool rtcp, - rtc::CopyOnWriteBuffer* packet, - const rtc::PacketTime& packet_time) { +bool SrtpTransport::SendRtcpPacket(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketOptions& options, + int flags) { if (!IsActive()) { - RTC_LOG(LS_WARNING) - << "Inactive SRTP transport received a packet. Drop it."; - return; + RTC_LOG(LS_ERROR) + << "Failed to send the packet because SRTP transport is inactive."; + return false; } + TRACE_EVENT0("webrtc", "SRTP Encode"); + uint8_t* data = packet->data(); + int len = static_cast(packet->size()); + if (!ProtectRtcp(data, len, static_cast(packet->capacity()), &len)) { + int type = -1; + cricket::GetRtcpType(data, len, &type); + RTC_LOG(LS_ERROR) << "Failed to protect RTCP packet: size=" << len + << ", type=" << type; + return false; + } + // Update the length of the packet now that we've added the auth tag. + packet->SetSize(len); + + return SendPacket(/*rtcp=*/true, packet, options, flags); +} + +void SrtpTransport::OnRtpPacketReceived(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& packet_time) { + if (!IsActive()) { + RTC_LOG(LS_WARNING) + << "Inactive SRTP transport received an RTP packet. Drop it."; + return; + } TRACE_EVENT0("webrtc", "SRTP Decode"); char* data = packet->data(); int len = static_cast(packet->size()); - bool res; - if (!rtcp) { - res = UnprotectRtp(data, len, &len); - if (!res) { - int seq_num = -1; - uint32_t ssrc = 0; - cricket::GetRtpSeqNum(data, len, &seq_num); - cricket::GetRtpSsrc(data, len, &ssrc); - RTC_LOG(LS_ERROR) << "Failed to unprotect RTP packet: size=" << len - << ", seqnum=" << seq_num << ", SSRC=" << ssrc; - return; - } - } else { - res = UnprotectRtcp(data, len, &len); - if (!res) { - int type = -1; - cricket::GetRtcpType(data, len, &type); - RTC_LOG(LS_ERROR) << "Failed to unprotect RTCP packet: size=" << len - << ", type=" << type; - return; - } + if (!UnprotectRtp(data, len, &len)) { + int seq_num = -1; + uint32_t ssrc = 0; + cricket::GetRtpSeqNum(data, len, &seq_num); + cricket::GetRtpSsrc(data, len, &ssrc); + RTC_LOG(LS_ERROR) << "Failed to unprotect RTP packet: size=" << len + << ", seqnum=" << seq_num << ", SSRC=" << ssrc; + return; } - packet->SetSize(len); - SignalPacketReceived(rtcp, packet, packet_time); + DemuxPacket(packet, packet_time); +} + +void SrtpTransport::OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& packet_time) { + if (!IsActive()) { + RTC_LOG(LS_WARNING) + << "Inactive SRTP transport received an RTCP packet. Drop it."; + return; + } + TRACE_EVENT0("webrtc", "SRTP Decode"); + char* data = packet->data(); + int len = static_cast(packet->size()); + if (!UnprotectRtcp(data, len, &len)) { + int type = -1; + cricket::GetRtcpType(data, len, &type); + RTC_LOG(LS_ERROR) << "Failed to unprotect RTCP packet: size=" << len + << ", type=" << type; + return; + } + packet->SetSize(len); + SignalRtcpPacketReceived(packet, packet_time); } void SrtpTransport::OnNetworkRouteChanged( @@ -414,7 +392,6 @@ void SrtpTransport::SetMetricsObserver( if (recv_rtcp_session_) { recv_rtcp_session_->SetMetricsObserver(metrics_observer_); } - rtp_transport_->SetMetricsObserver(metrics_observer); } } // namespace webrtc diff --git a/pc/srtptransport.h b/pc/srtptransport.h index 818aabb2a3..24c26f7f11 100644 --- a/pc/srtptransport.h +++ b/pc/srtptransport.h @@ -17,21 +17,20 @@ #include #include "p2p/base/icetransportinternal.h" -#include "pc/rtptransportinternaladapter.h" +#include "pc/rtptransport.h" #include "pc/srtpfilter.h" #include "pc/srtpsession.h" #include "rtc_base/checks.h" namespace webrtc { -// This class will eventually be a wrapper around RtpTransportInternal -// that protects and unprotects sent and received RTP packets. -class SrtpTransport : public RtpTransportInternalAdapter { +// This subclass of the RtpTransport is used for SRTP which is reponsible for +// protecting/unprotecting the packets. It provides interfaces to set the crypto +// parameters for the SrtpSession underneath. +class SrtpTransport : public RtpTransport { public: explicit SrtpTransport(bool rtcp_mux_enabled); - explicit SrtpTransport(std::unique_ptr rtp_transport); - bool SendRtpPacket(rtc::CopyOnWriteBuffer* packet, const rtc::PacketOptions& options, int flags) override; @@ -40,6 +39,16 @@ class SrtpTransport : public RtpTransportInternalAdapter { const rtc::PacketOptions& options, int flags) override; + // SrtpTransportInterface override. + // TODO(zhihuang): Implement these methods and replace the RtpTransportAdapter + // object. + RTCError SetSrtpSendKey(const cricket::CryptoParams& params) override { + return RTCError::OK(); + } + RTCError SetSrtpReceiveKey(const cricket::CryptoParams& params) override { + return RTCError::OK(); + } + // The transport becomes active if the send_session_ and recv_session_ are // created. bool IsActive() const; @@ -105,22 +114,12 @@ class SrtpTransport : public RtpTransportInternalAdapter { void ConnectToRtpTransport(); void CreateSrtpSessions(); - bool SendPacket(bool rtcp, - rtc::CopyOnWriteBuffer* packet, - const rtc::PacketOptions& options, - int flags); - - void OnPacketReceived(bool rtcp, - rtc::CopyOnWriteBuffer* packet, - const rtc::PacketTime& packet_time); - void OnReadyToSend(bool ready) { SignalReadyToSend(ready); } - void OnNetworkRouteChanged(rtc::Optional network_route); - - void OnWritableState(bool writable) { SignalWritableState(writable); } - - void OnSentPacket(const rtc::SentPacket& sent_packet) { - SignalSentPacket(sent_packet); - } + void OnRtpPacketReceived(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& packet_time) override; + void OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& packet_time) override; + void OnNetworkRouteChanged( + rtc::Optional network_route) override; bool ProtectRtp(void* data, int in_len, int max_len, int* out_len); @@ -139,7 +138,6 @@ class SrtpTransport : public RtpTransportInternalAdapter { bool UnprotectRtcp(void* data, int in_len, int* out_len); const std::string content_name_; - std::unique_ptr rtp_transport_; std::unique_ptr send_session_; std::unique_ptr recv_session_; diff --git a/pc/srtptransport_unittest.cc b/pc/srtptransport_unittest.cc index e0b2302e6f..9587e0ff50 100644 --- a/pc/srtptransport_unittest.cc +++ b/pc/srtptransport_unittest.cc @@ -42,8 +42,6 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { protected: SrtpTransportTest() { bool rtcp_mux_enabled = true; - auto rtp_transport1 = rtc::MakeUnique(rtcp_mux_enabled); - auto rtp_transport2 = rtc::MakeUnique(rtcp_mux_enabled); rtp_packet_transport1_ = rtc::MakeUnique("fake_packet_transport1"); @@ -54,38 +52,32 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { rtp_packet_transport1_->SetDestination(rtp_packet_transport2_.get(), asymmetric); - rtp_transport1->SetRtpPacketTransport(rtp_packet_transport1_.get()); - rtp_transport2->SetRtpPacketTransport(rtp_packet_transport2_.get()); + srtp_transport1_ = rtc::MakeUnique(rtcp_mux_enabled); + srtp_transport2_ = rtc::MakeUnique(rtcp_mux_enabled); - // Add payload type for RTP packet and RTCP packet. - rtp_transport1->AddHandledPayloadType(0x00); - rtp_transport2->AddHandledPayloadType(0x00); - rtp_transport1->AddHandledPayloadType(0xc9); - rtp_transport2->AddHandledPayloadType(0xc9); + srtp_transport1_->SetRtpPacketTransport(rtp_packet_transport1_.get()); + srtp_transport2_->SetRtpPacketTransport(rtp_packet_transport2_.get()); - srtp_transport1_ = - rtc::MakeUnique(std::move(rtp_transport1)); - srtp_transport2_ = - rtc::MakeUnique(std::move(rtp_transport2)); + srtp_transport1_->SignalRtcpPacketReceived.connect( + &rtp_sink1_, &TransportObserver::OnRtcpPacketReceived); + srtp_transport2_->SignalRtcpPacketReceived.connect( + &rtp_sink2_, &TransportObserver::OnRtcpPacketReceived); - srtp_transport1_->SignalPacketReceived.connect( - this, &SrtpTransportTest::OnPacketReceived1); - srtp_transport2_->SignalPacketReceived.connect( - this, &SrtpTransportTest::OnPacketReceived2); + RtpDemuxerCriteria demuxer_criteria; + // 0x00 is the payload type used in kPcmuFrame. + demuxer_criteria.payload_types = {0x00}; + + srtp_transport1_->RegisterRtpDemuxerSink(demuxer_criteria, &rtp_sink1_); + srtp_transport2_->RegisterRtpDemuxerSink(demuxer_criteria, &rtp_sink2_); } - void OnPacketReceived1(bool rtcp, - rtc::CopyOnWriteBuffer* packet, - const rtc::PacketTime& packet_time) { - RTC_LOG(LS_INFO) << "SrtpTransport1 Received a packet."; - last_recv_packet1_ = *packet; - } - - void OnPacketReceived2(bool rtcp, - rtc::CopyOnWriteBuffer* packet, - const rtc::PacketTime& packet_time) { - RTC_LOG(LS_INFO) << "SrtpTransport2 Received a packet."; - last_recv_packet2_ = *packet; + ~SrtpTransportTest() { + if (srtp_transport1_) { + srtp_transport1_->UnregisterRtpDemuxerSink(&rtp_sink1_); + } + if (srtp_transport2_) { + srtp_transport2_->UnregisterRtpDemuxerSink(&rtp_sink2_); + } } // With external auth enabled, SRTP doesn't write the auth tag and @@ -142,9 +134,9 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { if (srtp_transport1_->IsExternalAuthActive()) { TestRtpAuthParams(srtp_transport1_.get(), cipher_suite_name); } else { - ASSERT_TRUE(last_recv_packet2_.data()); - EXPECT_EQ(0, - memcmp(last_recv_packet2_.data(), original_rtp_data, rtp_len)); + ASSERT_TRUE(rtp_sink2_.last_recv_rtp_packet().data()); + EXPECT_EQ(0, memcmp(rtp_sink2_.last_recv_rtp_packet().data(), + original_rtp_data, rtp_len)); // Get the encrypted packet from underneath packet transport and verify // the data is actually encrypted. auto fake_rtp_packet_transport = static_cast( @@ -159,9 +151,9 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { if (srtp_transport2_->IsExternalAuthActive()) { TestRtpAuthParams(srtp_transport2_.get(), cipher_suite_name); } else { - ASSERT_TRUE(last_recv_packet1_.data()); - EXPECT_EQ(0, - memcmp(last_recv_packet1_.data(), original_rtp_data, rtp_len)); + ASSERT_TRUE(rtp_sink1_.last_recv_rtp_packet().data()); + EXPECT_EQ(0, memcmp(rtp_sink1_.last_recv_rtp_packet().data(), + original_rtp_data, rtp_len)); auto fake_rtp_packet_transport = static_cast( srtp_transport2_->rtp_packet_transport()); EXPECT_NE(0, memcmp(fake_rtp_packet_transport->last_sent_packet()->data(), @@ -170,12 +162,12 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { } void TestSendRecvRtcpPacket(const std::string& cipher_suite_name) { - size_t rtcp_len = sizeof(kRtcpReport); + size_t rtcp_len = sizeof(::kRtcpReport); size_t packet_size = rtcp_len + 4 + rtc::rtcp_auth_tag_len(cipher_suite_name); rtc::Buffer rtcp_packet_buffer(packet_size); char* rtcp_packet_data = rtcp_packet_buffer.data(); - memcpy(rtcp_packet_data, kRtcpReport, rtcp_len); + memcpy(rtcp_packet_data, ::kRtcpReport, rtcp_len); rtc::CopyOnWriteBuffer rtcp_packet1to2(rtcp_packet_data, rtcp_len, packet_size); @@ -187,8 +179,9 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { // that the packet can be successfully received and decrypted. ASSERT_TRUE(srtp_transport1_->SendRtcpPacket(&rtcp_packet1to2, options, cricket::PF_SRTP_BYPASS)); - ASSERT_TRUE(last_recv_packet2_.data()); - EXPECT_EQ(0, memcmp(last_recv_packet2_.data(), rtcp_packet_data, rtcp_len)); + ASSERT_TRUE(rtp_sink2_.last_recv_rtcp_packet().data()); + EXPECT_EQ(0, memcmp(rtp_sink2_.last_recv_rtcp_packet().data(), + rtcp_packet_data, rtcp_len)); // Get the encrypted packet from underneath packet transport and verify the // data is actually encrypted. auto fake_rtp_packet_transport = static_cast( @@ -199,8 +192,9 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { // Do the same thing in the opposite direction; ASSERT_TRUE(srtp_transport2_->SendRtcpPacket(&rtcp_packet2to1, options, cricket::PF_SRTP_BYPASS)); - ASSERT_TRUE(last_recv_packet1_.data()); - EXPECT_EQ(0, memcmp(last_recv_packet1_.data(), rtcp_packet_data, rtcp_len)); + ASSERT_TRUE(rtp_sink1_.last_recv_rtcp_packet().data()); + EXPECT_EQ(0, memcmp(rtp_sink1_.last_recv_rtcp_packet().data(), + rtcp_packet_data, rtcp_len)); fake_rtp_packet_transport = static_cast( srtp_transport2_->rtp_packet_transport()); EXPECT_NE(0, memcmp(fake_rtp_packet_transport->last_sent_packet()->data(), @@ -267,8 +261,9 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { // that the packet can be successfully received and decrypted. ASSERT_TRUE(srtp_transport1_->SendRtpPacket(&rtp_packet1to2, options, cricket::PF_SRTP_BYPASS)); - ASSERT_TRUE(last_recv_packet2_.data()); - EXPECT_EQ(0, memcmp(last_recv_packet2_.data(), original_rtp_data, rtp_len)); + ASSERT_TRUE(rtp_sink2_.last_recv_rtp_packet().data()); + EXPECT_EQ(0, memcmp(rtp_sink2_.last_recv_rtp_packet().data(), + original_rtp_data, rtp_len)); // Get the encrypted packet from underneath packet transport and verify the // data and header extension are actually encrypted. auto fake_rtp_packet_transport = static_cast( @@ -284,8 +279,9 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { // Do the same thing in the opposite direction; ASSERT_TRUE(srtp_transport2_->SendRtpPacket(&rtp_packet2to1, options, cricket::PF_SRTP_BYPASS)); - ASSERT_TRUE(last_recv_packet1_.data()); - EXPECT_EQ(0, memcmp(last_recv_packet1_.data(), original_rtp_data, rtp_len)); + ASSERT_TRUE(rtp_sink1_.last_recv_rtp_packet().data()); + EXPECT_EQ(0, memcmp(rtp_sink1_.last_recv_rtp_packet().data(), + original_rtp_data, rtp_len)); fake_rtp_packet_transport = static_cast( srtp_transport2_->rtp_packet_transport()); EXPECT_NE(0, memcmp(fake_rtp_packet_transport->last_sent_packet()->data(), @@ -328,8 +324,9 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { std::unique_ptr rtp_packet_transport1_; std::unique_ptr rtp_packet_transport2_; - rtc::CopyOnWriteBuffer last_recv_packet1_; - rtc::CopyOnWriteBuffer last_recv_packet2_; + TransportObserver rtp_sink1_; + TransportObserver rtp_sink2_; + int sequence_number_ = 0; }; diff --git a/pc/transportcontroller.cc b/pc/transportcontroller.cc index 4e20981da6..ee6e53afd0 100644 --- a/pc/transportcontroller.cc +++ b/pc/transportcontroller.cc @@ -458,15 +458,11 @@ webrtc::DtlsSrtpTransport* TransportController::CreateDtlsSrtpTransport( return existing_rtp_transport->dtls_srtp_transport; } - auto new_srtp_transport = - rtc::MakeUnique(rtcp_mux_enabled); - -#if defined(ENABLE_EXTERNAL_AUTH) - new_srtp_transport->EnableExternalAuth(); -#endif - auto new_dtls_srtp_transport = - rtc::MakeUnique(std::move(new_srtp_transport)); + rtc::MakeUnique(rtcp_mux_enabled); +#if defined(ENABLE_EXTERNAL_AUTH) + new_dtls_srtp_transport->EnableExternalAuth(); +#endif auto rtp_dtls_transport = CreateDtlsTransport_n( transport_name, cricket::ICE_CANDIDATE_COMPONENT_RTP);