diff --git a/api/test/fake_media_transport.h b/api/test/fake_media_transport.h index 730d4973eb..801852953d 100644 --- a/api/test/fake_media_transport.h +++ b/api/test/fake_media_transport.h @@ -14,6 +14,7 @@ #include #include #include +#include #include "absl/memory/memory.h" #include "api/media_transport_interface.h" @@ -78,9 +79,29 @@ class FakeMediaTransport : public MediaTransportInterface { } } + void AddTargetTransferRateObserver( + webrtc::TargetTransferRateObserver* observer) override { + RTC_CHECK(std::find(target_rate_observers_.begin(), + target_rate_observers_.end(), + observer) == target_rate_observers_.end()); + target_rate_observers_.push_back(observer); + } + + void RemoveTargetTransferRateObserver( + webrtc::TargetTransferRateObserver* observer) override { + auto it = std::find(target_rate_observers_.begin(), + target_rate_observers_.end(), observer); + if (it != target_rate_observers_.end()) { + target_rate_observers_.erase(it); + } + } + + int target_rate_observers_size() { return target_rate_observers_.size(); } + private: const MediaTransportSettings settings_; MediaTransportStateCallback* state_callback_; + std::vector target_rate_observers_; }; // Fake media transport factory creates fake media transport. diff --git a/call/BUILD.gn b/call/BUILD.gn index 237e507fbf..34c16efcc0 100644 --- a/call/BUILD.gn +++ b/call/BUILD.gn @@ -332,6 +332,7 @@ if (rtc_include_tests) { ":simulated_network", "..:webrtc_common", "../api:array_view", + "../api:fake_media_transport", "../api:libjingle_peerconnection_api", "../api:mock_audio_mixer", "../api/audio_codecs:builtin_audio_decoder_factory", diff --git a/call/call.cc b/call/call.cc index 1233ecdb7c..ae4525a72d 100644 --- a/call/call.cc +++ b/call/call.cc @@ -224,6 +224,15 @@ class Call final : public webrtc::Call, uint32_t allocated_without_feedback_bps, bool has_packet_feedback) override; + // This method is invoked when the media transport is created and when the + // media transport is being destructed. + // We only allow one media transport per connection. + // + // It should be called with non-null argument at most once, and if it was + // called with non-null argument, it has to be called with a null argument + // at least once after that. + void MediaTransportChange(MediaTransportInterface* media_transport) override; + private: DeliveryStatus DeliverRtcp(MediaType media_type, const uint8_t* packet, @@ -244,6 +253,10 @@ class Call final : public webrtc::Call, void UpdateHistograms(); void UpdateAggregateNetworkState(); + // If |media_transport| is not null, it registers the rate observer for the + // media transport. + void RegisterRateObserver() RTC_LOCKS_EXCLUDED(target_observer_crit_); + Clock* const clock_; const int num_cpu_cores_; @@ -362,6 +375,15 @@ class Call final : public webrtc::Call, // Declared last since it will issue callbacks from a task queue. Declaring it // last ensures that it is destroyed first and any running tasks are finished. std::unique_ptr transport_send_; + + // This is a precaution, since |MediaTransportChange| is not guaranteed to be + // invoked on a particular thread. + rtc::CriticalSection target_observer_crit_; + bool is_target_rate_observer_registered_ + RTC_GUARDED_BY(&target_observer_crit_) = false; + MediaTransportInterface* media_transport_ + RTC_GUARDED_BY(&target_observer_crit_) = nullptr; + RTC_DISALLOW_COPY_AND_ASSIGN(Call); }; } // namespace internal @@ -432,7 +454,6 @@ Call::Call(const Call::Config& config, video_send_delay_stats_(new SendDelayStats(clock_)), start_ms_(clock_->TimeInMilliseconds()) { RTC_DCHECK(config.event_log != nullptr); - transport_send->RegisterTargetTransferRateObserver(this); transport_send_ = std::move(transport_send); transport_send_ptr_ = transport_send_.get(); @@ -474,6 +495,43 @@ Call::~Call() { UpdateHistograms(); } +void Call::RegisterRateObserver() { + rtc::CritScope lock(&target_observer_crit_); + + if (is_target_rate_observer_registered_) { + return; + } + + is_target_rate_observer_registered_ = true; + + if (media_transport_) { + media_transport_->AddTargetTransferRateObserver(this); + } else { + transport_send_ptr_->RegisterTargetTransferRateObserver(this); + } +} + +void Call::MediaTransportChange(MediaTransportInterface* media_transport) { + rtc::CritScope lock(&target_observer_crit_); + + if (is_target_rate_observer_registered_) { + // Only used to unregister rate observer from media transport. Registration + // happens when the stream is created. + if (!media_transport && media_transport_) { + media_transport_->RemoveTargetTransferRateObserver(this); + media_transport_ = nullptr; + is_target_rate_observer_registered_ = false; + } + } else if (media_transport) { + RTC_DCHECK(media_transport_ == nullptr || + media_transport_ == media_transport) + << "media_transport_=" << (media_transport_ != nullptr) + << ", (media_transport_==media_transport)=" + << (media_transport_ == media_transport); + media_transport_ = media_transport; + } +} + void Call::UpdateHistograms() { RTC_HISTOGRAM_COUNTS_100000( "WebRTC.Call.LifetimeInSeconds", @@ -566,6 +624,14 @@ webrtc::AudioSendStream* Call::CreateAudioSendStream( const webrtc::AudioSendStream::Config& config) { TRACE_EVENT0("webrtc", "Call::CreateAudioSendStream"); RTC_DCHECK_CALLED_SEQUENTIALLY(&configuration_sequence_checker_); + + { + rtc::CritScope lock(&target_observer_crit_); + RTC_DCHECK(media_transport_ == config.media_transport); + } + + RegisterRateObserver(); + // Stream config is logged in AudioSendStream::ConfigureStream, as it may // change during the stream's lifetime. absl::optional suspended_rtp_state; @@ -695,6 +761,8 @@ webrtc::VideoSendStream* Call::CreateVideoSendStream( TRACE_EVENT0("webrtc", "Call::CreateVideoSendStream"); RTC_DCHECK_CALLED_SEQUENTIALLY(&configuration_sequence_checker_); + RegisterRateObserver(); + video_send_delay_stats_->AddSsrcs(config); for (size_t ssrc_index = 0; ssrc_index < config.rtp.ssrcs.size(); ++ssrc_index) { @@ -1031,6 +1099,18 @@ void Call::OnSentPacket(const rtc::SentPacket& sent_packet) { } void Call::OnTargetTransferRate(TargetTransferRate msg) { + // TODO(bugs.webrtc.org/9719) + // Call::OnTargetTransferRate requires that on target transfer rate is invoked + // from the worker queue (because bitrate_allocator_ requires it). Media + // transport does not guarantee the callback on the worker queue. + // When the threading model for MediaTransportInterface is update, reconsider + // changing this implementation. + if (!transport_send_ptr_->GetWorkerQueue()->IsCurrent()) { + transport_send_ptr_->GetWorkerQueue()->PostTask( + [this, msg] { this->OnTargetTransferRate(msg); }); + return; + } + uint32_t target_bitrate_bps = msg.target_rate.bps(); int loss_ratio_255 = msg.network_estimate.loss_rate_ratio * 255; uint8_t fraction_loss = diff --git a/call/call.h b/call/call.h index 40941e0d91..5cbbe907f7 100644 --- a/call/call.h +++ b/call/call.h @@ -58,6 +58,11 @@ class Call { virtual AudioSendStream* CreateAudioSendStream( const AudioSendStream::Config& config) = 0; + + // Gets called when media transport is created or removed. + virtual void MediaTransportChange( + MediaTransportInterface* media_transport_interface) = 0; + virtual void DestroyAudioSendStream(AudioSendStream* send_stream) = 0; virtual AudioReceiveStream* CreateAudioReceiveStream( diff --git a/call/call_unittest.cc b/call/call_unittest.cc index 83e96ff387..43c3355159 100644 --- a/call/call_unittest.cc +++ b/call/call_unittest.cc @@ -15,6 +15,7 @@ #include "absl/memory/memory.h" #include "api/audio_codecs/builtin_audio_decoder_factory.h" +#include "api/test/fake_media_transport.h" #include "api/test/mock_audio_mixer.h" #include "audio/audio_receive_stream.h" #include "audio/audio_send_stream.h" diff --git a/call/degraded_call.cc b/call/degraded_call.cc index e02a7f9c8d..a7ef41d490 100644 --- a/call/degraded_call.cc +++ b/call/degraded_call.cc @@ -215,4 +215,10 @@ PacketReceiver::DeliveryStatus DegradedCall::DeliverPacket( return status; } +void DegradedCall::MediaTransportChange( + MediaTransportInterface* media_transport) { + // TODO(bugs.webrtc.org/9719) We should add support for media transport here + // at some point. +} + } // namespace webrtc diff --git a/call/degraded_call.h b/call/degraded_call.h index ab88a51601..d78b1d1026 100644 --- a/call/degraded_call.h +++ b/call/degraded_call.h @@ -91,6 +91,7 @@ class DegradedCall : public Call, private Transport, private PacketReceiver { Clock* const clock_; const std::unique_ptr call_; + void MediaTransportChange(MediaTransportInterface* media_transport) override; const absl::optional send_config_; const std::unique_ptr send_process_thread_; SimulatedNetwork* send_simulated_network_; diff --git a/media/engine/fakewebrtccall.cc b/media/engine/fakewebrtccall.cc index 8cb4e9d192..6c5b8c74b8 100644 --- a/media/engine/fakewebrtccall.cc +++ b/media/engine/fakewebrtccall.cc @@ -644,4 +644,7 @@ void FakeCall::OnSentPacket(const rtc::SentPacket& sent_packet) { } } +void FakeCall::MediaTransportChange( + webrtc::MediaTransportInterface* media_transport_interface) {} + } // namespace cricket diff --git a/media/engine/fakewebrtccall.h b/media/engine/fakewebrtccall.h index dbcedb8f41..1b6deb0702 100644 --- a/media/engine/fakewebrtccall.h +++ b/media/engine/fakewebrtccall.h @@ -273,6 +273,9 @@ class FakeCall final : public webrtc::Call, public webrtc::PacketReceiver { int GetNumCreatedReceiveStreams() const; void SetStats(const webrtc::Call::Stats& stats); + void MediaTransportChange( + webrtc::MediaTransportInterface* media_transport_interface) override; + private: webrtc::AudioSendStream* CreateAudioSendStream( const webrtc::AudioSendStream::Config& config) override; diff --git a/pc/jseptransportcontroller.cc b/pc/jseptransportcontroller.cc index 9457ed7893..74a9ab6218 100644 --- a/pc/jseptransportcontroller.cc +++ b/pc/jseptransportcontroller.cc @@ -783,12 +783,12 @@ bool JsepTransportController::SetTransportForMid( mid_to_transport_[mid] = jsep_transport; return config_.transport_observer->OnTransportChanged( mid, jsep_transport->rtp_transport(), - jsep_transport->rtp_dtls_transport()); + jsep_transport->rtp_dtls_transport(), jsep_transport->media_transport()); } void JsepTransportController::RemoveTransportForMid(const std::string& mid) { - bool ret = - config_.transport_observer->OnTransportChanged(mid, nullptr, nullptr); + bool ret = config_.transport_observer->OnTransportChanged(mid, nullptr, + nullptr, nullptr); // Calling OnTransportChanged with nullptr should always succeed, since it is // only expected to fail when adding media to a transport (not removing). RTC_DCHECK(ret); @@ -1029,6 +1029,7 @@ RTCError JsepTransportController::MaybeCreateJsepTransport( // TODO(sukhanov): Proper error handling. RTC_CHECK(media_transport_result.ok()); + RTC_DCHECK(media_transport == nullptr); media_transport = std::move(media_transport_result.value()); } } @@ -1077,12 +1078,19 @@ void JsepTransportController::MaybeDestroyJsepTransport( return; } } + jsep_transports_by_name_.erase(mid); UpdateAggregateStates_n(); } void JsepTransportController::DestroyAllJsepTransports_n() { RTC_DCHECK(network_thread_->IsCurrent()); + + for (const auto& jsep_transport : jsep_transports_by_name_) { + config_.transport_observer->OnTransportChanged(jsep_transport.first, + nullptr, nullptr, nullptr); + } + jsep_transports_by_name_.clear(); } diff --git a/pc/jseptransportcontroller.h b/pc/jseptransportcontroller.h index 8d89795ce3..42b28c220d 100644 --- a/pc/jseptransportcontroller.h +++ b/pc/jseptransportcontroller.h @@ -57,7 +57,8 @@ class JsepTransportController : public sigslot::has_slots<> { virtual bool OnTransportChanged( const std::string& mid, RtpTransportInternal* rtp_transport, - cricket::DtlsTransportInternal* dtls_transport) = 0; + cricket::DtlsTransportInternal* dtls_transport, + MediaTransportInterface* media_transport) = 0; }; struct Config { diff --git a/pc/jseptransportcontroller_unittest.cc b/pc/jseptransportcontroller_unittest.cc index cb2023f0e2..129d22a4fc 100644 --- a/pc/jseptransportcontroller_unittest.cc +++ b/pc/jseptransportcontroller_unittest.cc @@ -11,6 +11,7 @@ #include #include +#include "api/media_transport_interface.h" #include "api/test/fake_media_transport.h" #include "p2p/base/fakedtlstransport.h" #include "p2p/base/fakeicetransport.h" @@ -298,12 +299,13 @@ class JsepTransportControllerTest : public JsepTransportController::Observer, } // JsepTransportController::Observer overrides. - bool OnTransportChanged( - const std::string& mid, - RtpTransportInternal* rtp_transport, - cricket::DtlsTransportInternal* dtls_transport) override { + bool OnTransportChanged(const std::string& mid, + RtpTransportInternal* rtp_transport, + cricket::DtlsTransportInternal* dtls_transport, + MediaTransportInterface* media_transport) override { changed_rtp_transport_by_mid_[mid] = rtp_transport; changed_dtls_transport_by_mid_[mid] = dtls_transport; + changed_media_transport_by_mid_[mid] = media_transport; return true; } @@ -328,7 +330,6 @@ class JsepTransportControllerTest : public JsepTransportController::Observer, // |network_thread_| should be destroyed after |transport_controller_| std::unique_ptr network_thread_; - std::unique_ptr transport_controller_; std::unique_ptr fake_transport_factory_; rtc::Thread* const signaling_thread_ = nullptr; bool signaled_on_non_signaling_thread_ = false; @@ -337,6 +338,12 @@ class JsepTransportControllerTest : public JsepTransportController::Observer, std::map changed_rtp_transport_by_mid_; std::map changed_dtls_transport_by_mid_; + std::map + changed_media_transport_by_mid_; + + // Transport controller needs to be destroyed first, because it may issue + // callbacks that modify the changed_*_by_mid in the destructor. + std::unique_ptr transport_controller_; }; TEST_F(JsepTransportControllerTest, GetRtpTransport) { diff --git a/pc/peerconnection.cc b/pc/peerconnection.cc index 47ecb4e1c1..982e52259e 100644 --- a/pc/peerconnection.cc +++ b/pc/peerconnection.cc @@ -6563,7 +6563,8 @@ void PeerConnection::DestroyChannelInterface( bool PeerConnection::OnTransportChanged( const std::string& mid, RtpTransportInternal* rtp_transport, - cricket::DtlsTransportInternal* dtls_transport) { + cricket::DtlsTransportInternal* dtls_transport, + MediaTransportInterface* media_transport) { bool ret = true; auto base_channel = GetChannel(mid); if (base_channel) { @@ -6572,6 +6573,9 @@ bool PeerConnection::OnTransportChanged( if (sctp_transport_ && mid == sctp_mid_) { sctp_transport_->SetDtlsTransport(dtls_transport); } + + call_->MediaTransportChange(media_transport); + return ret; } diff --git a/pc/peerconnection.h b/pc/peerconnection.h index b5ae9d2ae4..7e97afab7c 100644 --- a/pc/peerconnection.h +++ b/pc/peerconnection.h @@ -932,10 +932,10 @@ class PeerConnection : public PeerConnectionInternal, // from a session description, and the mapping from m= sections to transports // changed (as a result of BUNDLE negotiation, or m= sections being // rejected). - bool OnTransportChanged( - const std::string& mid, - RtpTransportInternal* rtp_transport, - cricket::DtlsTransportInternal* dtls_transport) override; + bool OnTransportChanged(const std::string& mid, + RtpTransportInternal* rtp_transport, + cricket::DtlsTransportInternal* dtls_transport, + MediaTransportInterface* media_transport) override; // Returns the observer. Will crash on CHECK if the observer is removed. PeerConnectionObserver* Observer() const;