diff --git a/webrtc/modules/rtp_rtcp/source/rtp_packet.cc b/webrtc/modules/rtp_rtcp/source/rtp_packet.cc index 72fd7892b2..b7d71c2789 100644 --- a/webrtc/modules/rtp_rtcp/source/rtp_packet.cc +++ b/webrtc/modules/rtp_rtcp/source/rtp_packet.cc @@ -52,26 +52,28 @@ constexpr size_t kDefaultPacketSize = 1500; // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | padding | Padding size | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +Packet::Packet() : Packet(nullptr, kDefaultPacketSize) {} + Packet::Packet(const ExtensionManager* extensions) - : extensions_(extensions), buffer_(kDefaultPacketSize) { - Clear(); -} + : Packet(extensions, kDefaultPacketSize) {} Packet::Packet(const ExtensionManager* extensions, size_t capacity) - : extensions_(extensions), buffer_(capacity) { + : buffer_(capacity) { RTC_DCHECK_GE(capacity, kFixedHeaderSize); Clear(); + if (extensions) { + IdentifyExtensions(*extensions); + } else { + for (size_t i = 0; i < kMaxExtensionHeaders; ++i) + extension_entries_[i].type = ExtensionManager::kInvalidType; + } } Packet::~Packet() {} -void Packet::IdentifyExtensions(const ExtensionManager* extensions) { - RTC_DCHECK(extensions); - extensions_ = extensions; - for (size_t i = 0; i < num_extensions_; ++i) { - uint8_t id = data()[extension_entries_[i].offset - 1] >> 4; - extension_entries_[i].type = extensions_->GetType(id); - } +void Packet::IdentifyExtensions(const ExtensionManager& extensions) { + for (size_t i = 0; i < kMaxExtensionHeaders; ++i) + extension_entries_[i].type = extensions.GetType(i + 1); } bool Packet::Parse(const uint8_t* buffer, size_t buffer_size) { @@ -211,8 +213,7 @@ void Packet::CopyHeaderFrom(const Packet& packet) { timestamp_ = packet.timestamp_; ssrc_ = packet.ssrc_; payload_offset_ = packet.payload_offset_; - num_extensions_ = packet.num_extensions_; - for (size_t i = 0; i < num_extensions_; ++i) { + for (size_t i = 0; i < kMaxExtensionHeaders; ++i) { extension_entries_[i] = packet.extension_entries_[i]; } extensions_size_ = packet.extensions_size_; @@ -253,7 +254,7 @@ void Packet::SetSsrc(uint32_t ssrc) { } void Packet::SetCsrcs(const std::vector& csrcs) { - RTC_DCHECK_EQ(num_extensions_, 0); + RTC_DCHECK_EQ(extensions_size_, 0); RTC_DCHECK_EQ(payload_size_, 0); RTC_DCHECK_EQ(padding_size_, 0); RTC_DCHECK_LE(csrcs.size(), 0x0fu); @@ -322,8 +323,11 @@ void Packet::Clear() { payload_offset_ = kFixedHeaderSize; payload_size_ = 0; padding_size_ = 0; - num_extensions_ = 0; extensions_size_ = 0; + for (ExtensionInfo& location : extension_entries_) { + location.offset = 0; + location.length = 0; + } memset(WriteAt(0), 0, kFixedHeaderSize); buffer_.SetSize(kFixedHeaderSize); @@ -362,8 +366,11 @@ bool Packet::ParseBuffer(const uint8_t* buffer, size_t size) { padding_size_ = 0; } - num_extensions_ = 0; extensions_size_ = 0; + for (ExtensionInfo& location : extension_entries_) { + location.offset = 0; + location.length = 0; + } if (has_extension) { /* RTP header extension, RFC 3550. 0 1 2 3 @@ -392,7 +399,7 @@ bool Packet::ParseBuffer(const uint8_t* buffer, size_t size) { constexpr uint8_t kPaddingId = 0; constexpr uint8_t kReservedId = 15; while (extensions_size_ + kOneByteHeaderSize < extensions_capacity) { - uint8_t id = buffer[extension_offset + extensions_size_] >> 4; + int id = buffer[extension_offset + extensions_size_] >> 4; if (id == kReservedId) { break; } else if (id == kPaddingId) { @@ -406,18 +413,16 @@ bool Packet::ParseBuffer(const uint8_t* buffer, size_t size) { LOG(LS_WARNING) << "Oversized rtp header extension."; break; } - if (num_extensions_ >= kMaxExtensionHeaders) { - LOG(LS_WARNING) << "Too many rtp header extensions."; - break; + + size_t idx = id - 1; + if (extension_entries_[idx].length != 0) { + LOG(LS_VERBOSE) << "Duplicate rtp header extension id " << id + << ". Overwriting."; } + extensions_size_ += kOneByteHeaderSize; - extension_entries_[num_extensions_].type = - extensions_ ? extensions_->GetType(id) - : ExtensionManager::kInvalidType; - extension_entries_[num_extensions_].length = length; - extension_entries_[num_extensions_].offset = - extension_offset + extensions_size_; - num_extensions_++; + extension_entries_[idx].offset = extension_offset + extensions_size_; + extension_entries_[idx].length = length; extensions_size_ += length; } } @@ -435,16 +440,19 @@ bool Packet::FindExtension(ExtensionType type, uint8_t length, uint16_t* offset) const { RTC_DCHECK(offset); - for (size_t i = 0; i < num_extensions_; ++i) { - if (extension_entries_[i].type == type) { - if (length != extension_entries_[i].length) { - LOG(LS_WARNING) << "Length mismatch for extension '" << type - << "': expected " << static_cast(length) - << ", received " - << static_cast(extension_entries_[i].length); + for (const ExtensionInfo& extension : extension_entries_) { + if (extension.type == type) { + if (extension.length == 0) { + // Extension is registered but not set. return false; } - *offset = extension_entries_[i].offset; + if (length != extension.length) { + LOG(LS_WARNING) << "Length mismatch for extension '" << type + << "': expected " << static_cast(length) + << ", received " << static_cast(extension.length); + return false; + } + *offset = extension.offset; return true; } } @@ -454,10 +462,28 @@ bool Packet::FindExtension(ExtensionType type, bool Packet::AllocateExtension(ExtensionType type, uint8_t length, uint16_t* offset) { - if (!extensions_) { - return false; + uint8_t extension_id = ExtensionManager::kInvalidId; + ExtensionInfo* extension_entry = nullptr; + for (size_t i = 0; i < kMaxExtensionHeaders; ++i) { + if (extension_entries_[i].type == type) { + extension_id = i + 1; + extension_entry = &extension_entries_[i]; + break; + } } - if (FindExtension(type, length, offset)) { + + if (!extension_entry) // Extension not registered. + return false; + + if (extension_entry->length != 0) { // Already allocated. + if (length != extension_entry->length) { + LOG(LS_WARNING) << "Length mismatch for extension '" << type + << "': expected " << static_cast(length) + << ", received " + << static_cast(extension_entry->length); + return false; + } + *offset = extension_entry->offset; return true; } @@ -469,10 +495,6 @@ bool Packet::AllocateExtension(ExtensionType type, return false; } - uint8_t extension_id = extensions_->GetId(type); - if (extension_id == ExtensionManager::kInvalidId) { - return false; - } RTC_DCHECK_GT(length, 0); RTC_DCHECK_LE(length, 16); @@ -491,9 +513,8 @@ bool Packet::AllocateExtension(ExtensionType type, (new_extensions_size + 3) / 4; // Wrap up to 32bit. // All checks passed, write down the extension. - if (num_extensions_ == 0) { + if (extensions_size_ == 0) { RTC_DCHECK_EQ(payload_offset_, kFixedHeaderSize + (num_csrc * 4)); - RTC_DCHECK_EQ(extensions_size_, 0); WriteAt(0, data()[0] | 0x10); // Set extension bit. // Profile specific ID always set to OneByteExtensionHeader. ByteWriter::WriteBigEndian(WriteAt(extensions_offset - 4), @@ -502,12 +523,10 @@ bool Packet::AllocateExtension(ExtensionType type, WriteAt(extensions_offset + extensions_size_, (extension_id << 4) | (length - 1)); - RTC_DCHECK(num_extensions_ < kMaxExtensionHeaders); - extension_entries_[num_extensions_].type = type; - extension_entries_[num_extensions_].length = length; + + extension_entry->length = length; *offset = extensions_offset + kOneByteHeaderSize + extensions_size_; - extension_entries_[num_extensions_].offset = *offset; - ++num_extensions_; + extension_entry->offset = *offset; extensions_size_ = new_extensions_size; // Update header length field. diff --git a/webrtc/modules/rtp_rtcp/source/rtp_packet.h b/webrtc/modules/rtp_rtcp/source/rtp_packet.h index 3f4d5769ba..2b3d38ec49 100644 --- a/webrtc/modules/rtp_rtcp/source/rtp_packet.h +++ b/webrtc/modules/rtp_rtcp/source/rtp_packet.h @@ -35,10 +35,8 @@ class Packet { // Parse and move given buffer into Packet. bool Parse(rtc::CopyOnWriteBuffer packet); - // Maps parsed extensions to their types to allow use of GetExtension. - // Used after parsing when |extensions| can't be provided until base rtp - // header is parsed. - void IdentifyExtensions(const ExtensionManager* extensions); + // Maps extensions id to their types. + void IdentifyExtensions(const ExtensionManager& extensions); // Header. bool Marker() const; @@ -106,6 +104,7 @@ class Packet { // packet creating and used if available in Parse function. // Adding and getting extensions will fail until |extensions| is // provided via constructor or IdentifyExtensions function. + Packet(); explicit Packet(const ExtensionManager* extensions); Packet(const Packet&) = default; Packet(const ExtensionManager* extensions, size_t capacity); @@ -144,8 +143,6 @@ class Packet { uint8_t* WriteAt(size_t offset); void WriteAt(size_t offset, uint8_t byte); - const ExtensionManager* extensions_; - // Header. bool marker_; uint8_t payload_type_; @@ -156,12 +153,9 @@ class Packet { size_t payload_offset_; // Match header size with csrcs and extensions. size_t payload_size_; - uint8_t num_extensions_ = 0; ExtensionInfo extension_entries_[kMaxExtensionHeaders]; uint16_t extensions_size_ = 0; // Unaligned. rtc::CopyOnWriteBuffer buffer_; - - Packet() = delete; }; template diff --git a/webrtc/modules/rtp_rtcp/source/rtp_packet_received.h b/webrtc/modules/rtp_rtcp/source/rtp_packet_received.h index e2222b9200..95674cf863 100644 --- a/webrtc/modules/rtp_rtcp/source/rtp_packet_received.h +++ b/webrtc/modules/rtp_rtcp/source/rtp_packet_received.h @@ -18,7 +18,7 @@ namespace webrtc { // Class to hold rtp packet with metadata for receiver side. class RtpPacketReceived : public rtp::Packet { public: - RtpPacketReceived() : Packet(nullptr) {} + RtpPacketReceived() = default; explicit RtpPacketReceived(const ExtensionManager* extensions) : Packet(extensions) {} diff --git a/webrtc/modules/rtp_rtcp/source/rtp_packet_unittest.cc b/webrtc/modules/rtp_rtcp/source/rtp_packet_unittest.cc index a84e5f87f1..fc88497b19 100644 --- a/webrtc/modules/rtp_rtcp/source/rtp_packet_unittest.cc +++ b/webrtc/modules/rtp_rtcp/source/rtp_packet_unittest.cc @@ -293,7 +293,7 @@ TEST(RtpPacketTest, ParseWithExtensionDelayed) { int32_t time_offset; EXPECT_FALSE(packet.GetExtension(&time_offset)); - packet.IdentifyExtensions(&extensions); + packet.IdentifyExtensions(extensions); EXPECT_TRUE(packet.GetExtension(&time_offset)); EXPECT_EQ(kTimeOffset, time_offset); EXPECT_EQ(0u, packet.payload_size());