diff --git a/modules/rtp_rtcp/source/rtcp_transceiver.cc b/modules/rtp_rtcp/source/rtcp_transceiver.cc index 28b95c0202..a937131e41 100644 --- a/modules/rtp_rtcp/source/rtcp_transceiver.cc +++ b/modules/rtp_rtcp/source/rtcp_transceiver.cc @@ -51,6 +51,28 @@ RtcpTransceiver::~RtcpTransceiver() { RTC_CHECK(!rtcp_transceiver_) << "Task queue is too busy to handle rtcp"; } +void RtcpTransceiver::AddMediaReceiverRtcpObserver( + uint32_t remote_ssrc, + MediaReceiverRtcpObserver* observer) { + rtc::WeakPtr ptr = ptr_; + task_queue_->PostTask([ptr, remote_ssrc, observer] { + if (ptr) + ptr->AddMediaReceiverRtcpObserver(remote_ssrc, observer); + }); +} + +void RtcpTransceiver::RemoveMediaReceiverRtcpObserver( + uint32_t remote_ssrc, + MediaReceiverRtcpObserver* observer, + std::unique_ptr on_removed) { + rtc::WeakPtr ptr = ptr_; + auto remove = [ptr, remote_ssrc, observer] { + if (ptr) + ptr->RemoveMediaReceiverRtcpObserver(remote_ssrc, observer); + }; + task_queue_->PostTaskAndReply(std::move(remove), std::move(on_removed)); +} + void RtcpTransceiver::ReceivePacket(rtc::CopyOnWriteBuffer packet) { rtc::WeakPtr ptr = ptr_; int64_t now_us = rtc::TimeMicros(); diff --git a/modules/rtp_rtcp/source/rtcp_transceiver.h b/modules/rtp_rtcp/source/rtcp_transceiver.h index 26091b57a5..8ce2db5243 100644 --- a/modules/rtp_rtcp/source/rtcp_transceiver.h +++ b/modules/rtp_rtcp/source/rtcp_transceiver.h @@ -32,6 +32,17 @@ class RtcpTransceiver { explicit RtcpTransceiver(const RtcpTransceiverConfig& config); ~RtcpTransceiver(); + // Registers observer to be notified about incoming rtcp packets. + // Calls to observer will be done on the |config.task_queue|. + void AddMediaReceiverRtcpObserver(uint32_t remote_ssrc, + MediaReceiverRtcpObserver* observer); + // Deregisters the observer. Might return before observer is deregistered. + // Posts |on_removed| task when observer is deregistered. + void RemoveMediaReceiverRtcpObserver( + uint32_t remote_ssrc, + MediaReceiverRtcpObserver* observer, + std::unique_ptr on_removed); + // Handles incoming rtcp packets. void ReceivePacket(rtc::CopyOnWriteBuffer packet); diff --git a/modules/rtp_rtcp/source/rtcp_transceiver_impl.cc b/modules/rtp_rtcp/source/rtcp_transceiver_impl.cc index 28d90bc463..2ac1c6bbc4 100644 --- a/modules/rtp_rtcp/source/rtcp_transceiver_impl.cc +++ b/modules/rtp_rtcp/source/rtcp_transceiver_impl.cc @@ -94,7 +94,7 @@ RtcpTransceiverImpl::RtcpTransceiverImpl(const RtcpTransceiverConfig& config) RtcpTransceiverImpl::~RtcpTransceiverImpl() = default; -void RtcpTransceiverImpl::AddMediaReceiverObserver( +void RtcpTransceiverImpl::AddMediaReceiverRtcpObserver( uint32_t remote_ssrc, MediaReceiverRtcpObserver* observer) { auto& stored = remote_senders_[remote_ssrc].observers; @@ -102,7 +102,7 @@ void RtcpTransceiverImpl::AddMediaReceiverObserver( stored.push_back(observer); } -void RtcpTransceiverImpl::RemoveMediaReceiverObserver( +void RtcpTransceiverImpl::RemoveMediaReceiverRtcpObserver( uint32_t remote_ssrc, MediaReceiverRtcpObserver* observer) { auto remote_sender_it = remote_senders_.find(remote_ssrc); diff --git a/modules/rtp_rtcp/source/rtcp_transceiver_impl.h b/modules/rtp_rtcp/source/rtcp_transceiver_impl.h index 35fee75ba8..abbc40fe03 100644 --- a/modules/rtp_rtcp/source/rtcp_transceiver_impl.h +++ b/modules/rtp_rtcp/source/rtcp_transceiver_impl.h @@ -38,10 +38,10 @@ class RtcpTransceiverImpl { explicit RtcpTransceiverImpl(const RtcpTransceiverConfig& config); ~RtcpTransceiverImpl(); - void AddMediaReceiverObserver(uint32_t remote_ssrc, - MediaReceiverRtcpObserver* observer); - void RemoveMediaReceiverObserver(uint32_t remote_ssrc, - MediaReceiverRtcpObserver* observer); + void AddMediaReceiverRtcpObserver(uint32_t remote_ssrc, + MediaReceiverRtcpObserver* observer); + void RemoveMediaReceiverRtcpObserver(uint32_t remote_ssrc, + MediaReceiverRtcpObserver* observer); void ReceivePacket(rtc::ArrayView packet, int64_t now_us); diff --git a/modules/rtp_rtcp/source/rtcp_transceiver_impl_unittest.cc b/modules/rtp_rtcp/source/rtcp_transceiver_impl_unittest.cc index a08eae7aa2..77d8e0cd3e 100644 --- a/modules/rtp_rtcp/source/rtcp_transceiver_impl_unittest.cc +++ b/modules/rtp_rtcp/source/rtcp_transceiver_impl_unittest.cc @@ -391,8 +391,8 @@ TEST(RtcpTransceiverImplTest, MultipleObserversOnSameSsrc) { StrictMock observer1; StrictMock observer2; RtcpTransceiverImpl rtcp_transceiver(DefaultTestConfig()); - rtcp_transceiver.AddMediaReceiverObserver(kRemoteSsrc, &observer1); - rtcp_transceiver.AddMediaReceiverObserver(kRemoteSsrc, &observer2); + rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc, &observer1); + rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc, &observer2); const NtpTime kRemoteNtp(0x9876543211); const uint32_t kRemoteRtp = 0x444555; @@ -412,14 +412,14 @@ TEST(RtcpTransceiverImplTest, DoesntCallsObserverAfterRemoved) { StrictMock observer1; StrictMock observer2; RtcpTransceiverImpl rtcp_transceiver(DefaultTestConfig()); - rtcp_transceiver.AddMediaReceiverObserver(kRemoteSsrc, &observer1); - rtcp_transceiver.AddMediaReceiverObserver(kRemoteSsrc, &observer2); + rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc, &observer1); + rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc, &observer2); SenderReport sr; sr.SetSenderSsrc(kRemoteSsrc); auto raw_packet = sr.Build(); - rtcp_transceiver.RemoveMediaReceiverObserver(kRemoteSsrc, &observer1); + rtcp_transceiver.RemoveMediaReceiverRtcpObserver(kRemoteSsrc, &observer1); EXPECT_CALL(observer1, OnSenderReport(_, _, _)).Times(0); EXPECT_CALL(observer2, OnSenderReport(_, _, _)); @@ -432,8 +432,8 @@ TEST(RtcpTransceiverImplTest, CallsObserverOnSenderReportBySenderSsrc) { StrictMock observer1; StrictMock observer2; RtcpTransceiverImpl rtcp_transceiver(DefaultTestConfig()); - rtcp_transceiver.AddMediaReceiverObserver(kRemoteSsrc1, &observer1); - rtcp_transceiver.AddMediaReceiverObserver(kRemoteSsrc2, &observer2); + rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc1, &observer1); + rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc2, &observer2); const NtpTime kRemoteNtp(0x9876543211); const uint32_t kRemoteRtp = 0x444555; @@ -454,8 +454,8 @@ TEST(RtcpTransceiverImplTest, CallsObserverOnByeBySenderSsrc) { StrictMock observer1; StrictMock observer2; RtcpTransceiverImpl rtcp_transceiver(DefaultTestConfig()); - rtcp_transceiver.AddMediaReceiverObserver(kRemoteSsrc1, &observer1); - rtcp_transceiver.AddMediaReceiverObserver(kRemoteSsrc2, &observer2); + rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc1, &observer1); + rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc2, &observer2); Bye bye; bye.SetSenderSsrc(kRemoteSsrc1); @@ -472,8 +472,8 @@ TEST(RtcpTransceiverImplTest, CallsObserverOnTargetBitrateBySenderSsrc) { StrictMock observer1; StrictMock observer2; RtcpTransceiverImpl rtcp_transceiver(DefaultTestConfig()); - rtcp_transceiver.AddMediaReceiverObserver(kRemoteSsrc1, &observer1); - rtcp_transceiver.AddMediaReceiverObserver(kRemoteSsrc2, &observer2); + rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc1, &observer1); + rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc2, &observer2); webrtc::rtcp::TargetBitrate target_bitrate; target_bitrate.AddTargetBitrate(0, 0, /*target_bitrate_kbps=*/10); @@ -499,7 +499,7 @@ TEST(RtcpTransceiverImplTest, SkipsIncorrectTargetBitrateEntries) { const uint32_t kRemoteSsrc = 12345; MockMediaReceiverRtcpObserver observer; RtcpTransceiverImpl rtcp_transceiver(DefaultTestConfig()); - rtcp_transceiver.AddMediaReceiverObserver(kRemoteSsrc, &observer); + rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc, &observer); webrtc::rtcp::TargetBitrate target_bitrate; target_bitrate.AddTargetBitrate(0, 0, /*target_bitrate_kbps=*/10); @@ -521,7 +521,7 @@ TEST(RtcpTransceiverImplTest, CallsObserverOnByeBehindSenderReport) { const uint32_t kRemoteSsrc = 12345; MockMediaReceiverRtcpObserver observer; RtcpTransceiverImpl rtcp_transceiver(DefaultTestConfig()); - rtcp_transceiver.AddMediaReceiverObserver(kRemoteSsrc, &observer); + rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc, &observer); CompoundPacket compound; SenderReport sr; @@ -541,7 +541,7 @@ TEST(RtcpTransceiverImplTest, CallsObserverOnByeBehindUnknownRtcpPacket) { const uint32_t kRemoteSsrc = 12345; MockMediaReceiverRtcpObserver observer; RtcpTransceiverImpl rtcp_transceiver(DefaultTestConfig()); - rtcp_transceiver.AddMediaReceiverObserver(kRemoteSsrc, &observer); + rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc, &observer); CompoundPacket compound; // Use Application-Defined rtcp packet as unknown. diff --git a/modules/rtp_rtcp/source/rtcp_transceiver_unittest.cc b/modules/rtp_rtcp/source/rtcp_transceiver_unittest.cc index 0a374e63b9..74986bc6cf 100644 --- a/modules/rtp_rtcp/source/rtcp_transceiver_unittest.cc +++ b/modules/rtp_rtcp/source/rtcp_transceiver_unittest.cc @@ -10,6 +10,9 @@ #include "modules/rtp_rtcp/source/rtcp_transceiver.h" +#include + +#include "modules/rtp_rtcp/source/rtcp_packet/sender_report.h" #include "modules/rtp_rtcp/source/rtcp_packet/transport_feedback.h" #include "rtc_base/event.h" #include "rtc_base/ptr_util.h" @@ -22,6 +25,7 @@ namespace { using ::testing::AtLeast; using ::testing::Invoke; using ::testing::InvokeWithoutArgs; +using ::testing::IsNull; using ::testing::NiceMock; using ::testing::_; using ::webrtc::MockTransport; @@ -29,10 +33,17 @@ using ::webrtc::RtcpTransceiver; using ::webrtc::RtcpTransceiverConfig; using ::webrtc::rtcp::TransportFeedback; +class MockMediaReceiverRtcpObserver : public webrtc::MediaReceiverRtcpObserver { + public: + MOCK_METHOD3(OnSenderReport, void(uint32_t, webrtc::NtpTime, uint32_t)); +}; + +constexpr int kTimeoutMs = 1000; + void WaitPostedTasks(rtc::TaskQueue* queue) { rtc::Event done(false, false); queue->PostTask([&done] { done.Set(); }); - ASSERT_TRUE(done.Wait(1000)); + ASSERT_TRUE(done.Wait(kTimeoutMs)); } TEST(RtcpTransceiverTest, SendsRtcpOnTaskQueueWhenCreatedOffTaskQueue) { @@ -84,6 +95,71 @@ TEST(RtcpTransceiverTest, CanBeDestoryedOnTaskQueue) { WaitPostedTasks(&queue); } +// Use rtp timestamp to distinguish different incoming sender reports. +rtc::CopyOnWriteBuffer CreateSenderReport(uint32_t ssrc, uint32_t rtp_time) { + webrtc::rtcp::SenderReport sr; + sr.SetSenderSsrc(ssrc); + sr.SetRtpTimestamp(rtp_time); + rtc::Buffer buffer = sr.Build(); + // Switch to an efficient way creating CopyOnWriteBuffer from RtcpPacket when + // there is one. Until then do not worry about extra memcpy in test. + return rtc::CopyOnWriteBuffer(buffer.data(), buffer.size()); +} + +TEST(RtcpTransceiverTest, DoesntPostToRtcpObserverAfterCallToRemove) { + const uint32_t kRemoteSsrc = 1234; + MockTransport null_transport; + rtc::TaskQueue queue("rtcp"); + RtcpTransceiverConfig config; + config.outgoing_transport = &null_transport; + config.task_queue = &queue; + RtcpTransceiver rtcp_transceiver(config); + rtc::Event observer_deleted(false, false); + + auto observer = rtc::MakeUnique(); + EXPECT_CALL(*observer, OnSenderReport(kRemoteSsrc, _, 1)); + EXPECT_CALL(*observer, OnSenderReport(kRemoteSsrc, _, 2)).Times(0); + + rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc, observer.get()); + rtcp_transceiver.ReceivePacket(CreateSenderReport(kRemoteSsrc, 1)); + rtcp_transceiver.RemoveMediaReceiverRtcpObserver( + kRemoteSsrc, observer.get(), + /*on_removed=*/rtc::NewClosure([&] { + observer.reset(); + observer_deleted.Set(); + })); + rtcp_transceiver.ReceivePacket(CreateSenderReport(kRemoteSsrc, 2)); + + EXPECT_TRUE(observer_deleted.Wait(kTimeoutMs)); + WaitPostedTasks(&queue); +} + +TEST(RtcpTransceiverTest, RemoveMediaReceiverRtcpObserverIsNonBlocking) { + const uint32_t kRemoteSsrc = 1234; + MockTransport null_transport; + rtc::TaskQueue queue("rtcp"); + RtcpTransceiverConfig config; + config.outgoing_transport = &null_transport; + config.task_queue = &queue; + RtcpTransceiver rtcp_transceiver(config); + auto observer = rtc::MakeUnique(); + rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc, observer.get()); + + rtc::Event queue_blocker(false, false); + rtc::Event observer_deleted(false, false); + queue.PostTask([&] { EXPECT_TRUE(queue_blocker.Wait(kTimeoutMs)); }); + rtcp_transceiver.RemoveMediaReceiverRtcpObserver( + kRemoteSsrc, observer.get(), + /*on_removed=*/rtc::NewClosure([&] { + observer.reset(); + observer_deleted.Set(); + })); + + EXPECT_THAT(observer, Not(IsNull())); + queue_blocker.Set(); + EXPECT_TRUE(observer_deleted.Wait(kTimeoutMs)); +} + TEST(RtcpTransceiverTest, CanCallSendCompoundPacketFromAnyThread) { MockTransport outgoing_transport; rtc::TaskQueue queue("rtcp");