diff --git a/media/base/rtp_utils.cc b/media/base/rtp_utils.cc index 57e719b18f..06699333b2 100644 --- a/media/base/rtp_utils.cc +++ b/media/base/rtp_utils.cc @@ -277,26 +277,22 @@ bool SetRtpHeader(void* data, size_t len, const RtpHeader& header) { SetRtpSsrc(data, len, header.ssrc)); } -bool IsRtpPacket(const void* data, size_t len) { - if (len < kMinRtpPacketLen) - return false; +static bool HasCorrectRtpVersion(rtc::ArrayView packet) { + return reinterpret_cast(packet.data())[0] >> 6 == kRtpVersion; +} - return (static_cast(data)[0] >> 6) == kRtpVersion; +bool IsRtpPacket(rtc::ArrayView packet) { + return packet.size() >= kMinRtpPacketLen && HasCorrectRtpVersion(packet); } // Check the RTP payload type. If 63 < payload type < 96, it's RTCP. // For additional details, see http://tools.ietf.org/html/rfc5761. -bool IsRtcpPacket(const char* data, size_t len) { - if (len < kMinRtcpPacketLen) { +bool IsRtcpPacket(rtc::ArrayView packet) { + if (packet.size() < kMinRtcpPacketLen || !HasCorrectRtpVersion(packet)) { return false; } - // RTCP must be a valid RTP packet. - if ((static_cast(data[0]) >> 6) != kRtpVersion) { - return false; - } - - char pt = data[1] & 0x7F; + char pt = packet[1] & 0x7F; return (63 < pt) && (pt < 96); } @@ -304,13 +300,35 @@ bool IsValidRtpPayloadType(int payload_type) { return payload_type >= 0 && payload_type <= 127; } -bool IsValidRtpRtcpPacketSize(bool rtcp, size_t size) { - return (rtcp ? size >= kMinRtcpPacketLen : size >= kMinRtpPacketLen) && - size <= kMaxRtpPacketLen; +bool IsValidRtpPacketSize(RtpPacketType packet_type, size_t size) { + // TODO(webrtc:10418): uncomment when relands. + // RTC_DCHECK_NE(RtpPacketType::kUnknown, packet_type); + size_t min_packet_length = packet_type == RtpPacketType::kRtcp + ? kMinRtcpPacketLen + : kMinRtpPacketLen; + return size >= min_packet_length && size <= kMaxRtpPacketLen; } -const char* RtpRtcpStringLiteral(bool rtcp) { - return rtcp ? "RTCP" : "RTP"; +absl::string_view RtpPacketTypeToString(RtpPacketType packet_type) { + switch (packet_type) { + case RtpPacketType::kRtp: + return "RTP"; + case RtpPacketType::kRtcp: + return "RTCP"; + case RtpPacketType::kUnknown: + return "Unknown"; + } +} + +RtpPacketType InferRtpPacketType(rtc::ArrayView packet) { + // RTCP packets are RTP packets so must check that first. + if (IsRtcpPacket(packet)) { + return RtpPacketType::kRtcp; + } + if (IsRtpPacket(packet)) { + return RtpPacketType::kRtp; + } + return RtpPacketType::kUnknown; } bool ValidateRtpHeader(const uint8_t* rtp, @@ -475,7 +493,9 @@ bool ApplyPacketOptions(uint8_t* data, } // Making sure we have a valid RTP packet at the end. - if (!IsRtpPacket(data + rtp_start_pos, rtp_length) || + auto packet = rtc::MakeArrayView( + reinterpret_cast(data + rtp_start_pos), rtp_length); + if (!IsRtpPacket(packet) || !ValidateRtpHeader(data + rtp_start_pos, rtp_length, nullptr)) { RTC_NOTREACHED(); return false; diff --git a/media/base/rtp_utils.h b/media/base/rtp_utils.h index 93f3103b9c..9ef9f9c7ba 100644 --- a/media/base/rtp_utils.h +++ b/media/base/rtp_utils.h @@ -11,6 +11,8 @@ #ifndef MEDIA_BASE_RTP_UTILS_H_ #define MEDIA_BASE_RTP_UTILS_H_ +#include "absl/strings/string_view.h" +#include "api/array_view.h" #include "rtc_base/byte_order.h" #include "rtc_base/system/rtc_export.h" @@ -41,6 +43,12 @@ enum RtcpTypes { kRtcpTypePSFB = 206, // Payload-specific Feedback message payload type. }; +enum class RtpPacketType { + kRtp, + kRtcp, + kUnknown, +}; + bool GetRtpPayloadType(const void* data, size_t len, int* value); bool GetRtpSeqNum(const void* data, size_t len, int* value); bool GetRtpTimestamp(const void* data, size_t len, uint32_t* value); @@ -54,19 +62,19 @@ bool SetRtpSsrc(void* data, size_t len, uint32_t value); // Assumes version 2, no padding, no extensions, no csrcs. bool SetRtpHeader(void* data, size_t len, const RtpHeader& header); -bool IsRtpPacket(const void* data, size_t len); +bool IsRtpPacket(rtc::ArrayView packet); -bool IsRtcpPacket(const char* data, size_t len); +bool IsRtcpPacket(rtc::ArrayView packet); +// Checks the packet header to determine if it can be an RTP or RTCP packet. +RtpPacketType InferRtpPacketType(rtc::ArrayView packet); // True if |payload type| is 0-127. bool IsValidRtpPayloadType(int payload_type); // True if |size| is appropriate for the indicated packet type. -bool IsValidRtpRtcpPacketSize(bool rtcp, size_t size); +bool IsValidRtpPacketSize(RtpPacketType packet_type, size_t size); -// TODO(zstein): Consider using an enum instead of a bool to differentiate -// between RTP and RTCP. -// Returns "RTCP" or "RTP" according to |rtcp|. -const char* RtpRtcpStringLiteral(bool rtcp); +// Returns "RTCP", "RTP" or "Unknown" according to |packet_type|. +absl::string_view RtpPacketTypeToString(RtpPacketType packet_type); // Verifies that a packet has a valid RTP header. bool RTC_EXPORT ValidateRtpHeader(const uint8_t* rtp, diff --git a/media/base/rtp_utils_unittest.cc b/media/base/rtp_utils_unittest.cc index 8ac68a4472..d88b1606dc 100644 --- a/media/base/rtp_utils_unittest.cc +++ b/media/base/rtp_utils_unittest.cc @@ -79,8 +79,18 @@ static uint8_t kRtpMsgWithAbsSendTimeExtension[] = { // Index of AbsSendTimeExtn data in message |kRtpMsgWithAbsSendTimeExtension|. static const int kAstIndexInRtpMsg = 21; +static const rtc::ArrayView kPcmuFrameArrayView = + rtc::MakeArrayView(reinterpret_cast(kPcmuFrame), + sizeof(kPcmuFrame)); +static const rtc::ArrayView kRtcpReportArrayView = + rtc::MakeArrayView(reinterpret_cast(kRtcpReport), + sizeof(kRtcpReport)); +static const rtc::ArrayView kInvalidPacketArrayView = + rtc::MakeArrayView(reinterpret_cast(kInvalidPacket), + sizeof(kInvalidPacket)); + TEST(RtpUtilsTest, GetRtp) { - EXPECT_TRUE(IsRtpPacket(kPcmuFrame, sizeof(kPcmuFrame))); + EXPECT_TRUE(IsRtpPacket(kPcmuFrameArrayView)); int pt; EXPECT_TRUE(GetRtpPayloadType(kPcmuFrame, sizeof(kPcmuFrame), &pt)); @@ -344,4 +354,11 @@ TEST(RtpUtilsTest, ApplyPacketOptionsWithAuthParamsAndAbsSendTime) { sizeof(kExpectedTimestamp))); } +TEST(RtpUtilsTest, InferRtpPacketType) { + EXPECT_EQ(RtpPacketType::kRtp, InferRtpPacketType(kPcmuFrameArrayView)); + EXPECT_EQ(RtpPacketType::kRtcp, InferRtpPacketType(kRtcpReportArrayView)); + EXPECT_EQ(RtpPacketType::kUnknown, + InferRtpPacketType(kInvalidPacketArrayView)); +} + } // namespace cricket diff --git a/pc/channel.cc b/pc/channel.cc index 991e9e338c..647663e250 100644 --- a/pc/channel.cc +++ b/pc/channel.cc @@ -93,11 +93,6 @@ static void SafeSetError(const std::string& message, std::string* error_desc) { } } -static bool ValidPacket(bool rtcp, const rtc::CopyOnWriteBuffer* packet) { - // Check the packet size. We could check the header too if needed. - return packet && IsValidRtpRtcpPacketSize(rtcp, packet->size()); -} - template void RtpParametersFromMediaDescription( const MediaContentDescriptionImpl* desc, @@ -402,6 +397,8 @@ void BaseChannel::OnTransportReadyToSend(bool ready) { bool BaseChannel::SendPacket(bool rtcp, rtc::CopyOnWriteBuffer* packet, const rtc::PacketOptions& options) { + // Until all the code is migrated to use RtpPacketType instead of bool. + RtpPacketType packet_type = rtcp ? RtpPacketType::kRtcp : RtpPacketType::kRtp; // SendPacket gets called from MediaEngine, on a pacer or an encoder thread. // If the thread is not our network thread, we will post to our network // so that the real work happens on our network. This avoids us having to @@ -430,9 +427,9 @@ bool BaseChannel::SendPacket(bool rtcp, } // Protect ourselves against crazy data. - if (!ValidPacket(rtcp, packet)) { + if (!IsValidRtpPacketSize(packet_type, packet->size())) { RTC_LOG(LS_ERROR) << "Dropping outgoing " << content_name_ << " " - << RtpRtcpStringLiteral(rtcp) + << RtpPacketTypeToString(packet_type) << " packet: wrong size=" << packet->size(); return false; } @@ -524,7 +521,9 @@ void BaseChannel::OnPacketReceived(bool rtcp, // for us to just eat packets here. This is all sidestepped if RTCP mux // is used anyway. RTC_LOG(LS_WARNING) - << "Can't process incoming " << RtpRtcpStringLiteral(rtcp) + << "Can't process incoming " + << RtpPacketTypeToString(rtcp ? RtpPacketType::kRtcp + : RtpPacketType::kRtp) << " packet when SRTP is inactive and crypto is required"; return; } diff --git a/pc/rtp_transport.cc b/pc/rtp_transport.cc index 20559e0907..bd11e57cbf 100644 --- a/pc/rtp_transport.cc +++ b/pc/rtp_transport.cc @@ -184,10 +184,10 @@ RtpTransportParameters RtpTransport::GetParameters() const { return parameters_; } -void RtpTransport::DemuxPacket(rtc::CopyOnWriteBuffer* packet, +void RtpTransport::DemuxPacket(rtc::CopyOnWriteBuffer packet, int64_t packet_time_us) { webrtc::RtpPacketReceived parsed_packet(&header_extension_map_); - if (!parsed_packet.Parse(std::move(*packet))) { + if (!parsed_packet.Parse(std::move(packet))) { RTC_LOG(LS_ERROR) << "Failed to parse the incoming RTP packet before demuxing. Drop it."; return; @@ -233,14 +233,14 @@ void RtpTransport::OnSentPacket(rtc::PacketTransportInternal* packet_transport, SignalSentPacket(sent_packet); } -void RtpTransport::OnRtpPacketReceived(rtc::CopyOnWriteBuffer* packet, +void RtpTransport::OnRtpPacketReceived(rtc::CopyOnWriteBuffer packet, int64_t packet_time_us) { DemuxPacket(packet, packet_time_us); } -void RtpTransport::OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, +void RtpTransport::OnRtcpPacketReceived(rtc::CopyOnWriteBuffer packet, int64_t packet_time_us) { - SignalRtcpPacketReceived(packet, packet_time_us); + SignalRtcpPacketReceived(&packet, packet_time_us); } void RtpTransport::OnReadPacket(rtc::PacketTransportInternal* transport, @@ -252,27 +252,26 @@ void RtpTransport::OnReadPacket(rtc::PacketTransportInternal* transport, // When using RTCP multiplexing we might get RTCP packets on the RTP // transport. We check the RTP payload type to determine if it is RTCP. - bool rtcp = - transport == rtcp_packet_transport() || cricket::IsRtcpPacket(data, len); - + auto array_view = rtc::MakeArrayView(data, len); + cricket::RtpPacketType packet_type = cricket::InferRtpPacketType(array_view); // Filter out the packet that is neither RTP nor RTCP. - if (!rtcp && !cricket::IsRtpPacket(data, len)) { + if (packet_type == cricket::RtpPacketType::kUnknown) { + return; + } + + // Protect ourselves against crazy data. + if (!cricket::IsValidRtpPacketSize(packet_type, len)) { + RTC_LOG(LS_ERROR) << "Dropping incoming " + << cricket::RtpPacketTypeToString(packet_type) + << " packet: wrong size=" << len; return; } rtc::CopyOnWriteBuffer packet(data, len); - // Protect ourselves against crazy data. - if (!cricket::IsValidRtpRtcpPacketSize(rtcp, packet.size())) { - RTC_LOG(LS_ERROR) << "Dropping incoming " - << cricket::RtpRtcpStringLiteral(rtcp) - << " packet: wrong size=" << packet.size(); - return; - } - - if (rtcp) { - OnRtcpPacketReceived(&packet, packet_time_us); + if (packet_type == cricket::RtpPacketType::kRtcp) { + OnRtcpPacketReceived(std::move(packet), packet_time_us); } else { - OnRtpPacketReceived(&packet, packet_time_us); + OnRtpPacketReceived(std::move(packet), packet_time_us); } } diff --git a/pc/rtp_transport.h b/pc/rtp_transport.h index f188a17a39..dfdabbc7fe 100644 --- a/pc/rtp_transport.h +++ b/pc/rtp_transport.h @@ -87,7 +87,7 @@ class RtpTransport : public RtpTransportInternal { RtpTransportAdapter* GetInternal() override; // These methods will be used in the subclasses. - void DemuxPacket(rtc::CopyOnWriteBuffer* packet, int64_t packet_time_us); + void DemuxPacket(rtc::CopyOnWriteBuffer packet, int64_t packet_time_us); bool SendPacket(bool rtcp, rtc::CopyOnWriteBuffer* packet, @@ -97,9 +97,9 @@ class RtpTransport : public RtpTransportInternal { // Overridden by SrtpTransport. virtual void OnNetworkRouteChanged( absl::optional network_route); - virtual void OnRtpPacketReceived(rtc::CopyOnWriteBuffer* packet, + virtual void OnRtpPacketReceived(rtc::CopyOnWriteBuffer packet, int64_t packet_time_us); - virtual void OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, + virtual void OnRtcpPacketReceived(rtc::CopyOnWriteBuffer packet, int64_t packet_time_us); // Overridden by SrtpTransport and DtlsSrtpTransport. virtual void OnWritableState(rtc::PacketTransportInternal* packet_transport); diff --git a/pc/srtp_transport.cc b/pc/srtp_transport.cc index c7e4f0e9e8..20e32f5a1b 100644 --- a/pc/srtp_transport.cc +++ b/pc/srtp_transport.cc @@ -13,6 +13,7 @@ #include #include #include +#include #include #include "media/base/rtp_utils.h" @@ -197,7 +198,7 @@ bool SrtpTransport::SendRtcpPacket(rtc::CopyOnWriteBuffer* packet, return SendPacket(/*rtcp=*/true, packet, options, flags); } -void SrtpTransport::OnRtpPacketReceived(rtc::CopyOnWriteBuffer* packet, +void SrtpTransport::OnRtpPacketReceived(rtc::CopyOnWriteBuffer packet, int64_t packet_time_us) { if (!IsSrtpActive()) { RTC_LOG(LS_WARNING) @@ -205,8 +206,8 @@ void SrtpTransport::OnRtpPacketReceived(rtc::CopyOnWriteBuffer* packet, return; } TRACE_EVENT0("webrtc", "SRTP Decode"); - char* data = packet->data(); - int len = rtc::checked_cast(packet->size()); + char* data = packet.data(); + int len = rtc::checked_cast(packet.size()); if (!UnprotectRtp(data, len, &len)) { int seq_num = -1; uint32_t ssrc = 0; @@ -225,11 +226,11 @@ void SrtpTransport::OnRtpPacketReceived(rtc::CopyOnWriteBuffer* packet, ++decryption_failure_count_; return; } - packet->SetSize(len); - DemuxPacket(packet, packet_time_us); + packet.SetSize(len); + DemuxPacket(std::move(packet), packet_time_us); } -void SrtpTransport::OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, +void SrtpTransport::OnRtcpPacketReceived(rtc::CopyOnWriteBuffer packet, int64_t packet_time_us) { if (!IsSrtpActive()) { RTC_LOG(LS_WARNING) @@ -237,8 +238,8 @@ void SrtpTransport::OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, return; } TRACE_EVENT0("webrtc", "SRTP Decode"); - char* data = packet->data(); - int len = rtc::checked_cast(packet->size()); + char* data = packet.data(); + int len = rtc::checked_cast(packet.size()); if (!UnprotectRtcp(data, len, &len)) { int type = -1; cricket::GetRtcpType(data, len, &type); @@ -246,8 +247,8 @@ void SrtpTransport::OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, << ", type=" << type; return; } - packet->SetSize(len); - SignalRtcpPacketReceived(packet, packet_time_us); + packet.SetSize(len); + SignalRtcpPacketReceived(&packet, packet_time_us); } void SrtpTransport::OnNetworkRouteChanged( diff --git a/pc/srtp_transport.h b/pc/srtp_transport.h index 75127119a8..e725733e05 100644 --- a/pc/srtp_transport.h +++ b/pc/srtp_transport.h @@ -116,9 +116,9 @@ class SrtpTransport : public RtpTransport { void ConnectToRtpTransport(); void CreateSrtpSessions(); - void OnRtpPacketReceived(rtc::CopyOnWriteBuffer* packet, + void OnRtpPacketReceived(rtc::CopyOnWriteBuffer packet, int64_t packet_time_us) override; - void OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, + void OnRtcpPacketReceived(rtc::CopyOnWriteBuffer packet, int64_t packet_time_us) override; void OnNetworkRouteChanged( absl::optional network_route) override;