From bfd9ba8802be081a9a5201d7afb6284b3954cdff Mon Sep 17 00:00:00 2001 From: Tomas Gunnarsson Date: Sun, 18 Apr 2021 11:55:57 +0200 Subject: [PATCH] Fix unsafe variable access in RTCStatsCollector MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit With this change, all production callers of BaseChannel::transport_name() will be making the call from the right thread and we can safely delegate the call to the transport itself. Some tests still need to be updated. This facilitates the main goal of not needing synchronization inside of the channel classes, being able to apply thread checks and eventually remove thread hops from the channel classes. A downside of this particular change is that a blocking call to the network thread from the signaling thread inside of RTCStatsCollector needs to be done. This is done once though and fixes a race. Bug: webrtc:12601, webrtc:11687, webrtc:12644 Change-Id: I85f34f3341a06da9a9efd936b1d36722b10ec487 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/213080 Reviewed-by: Henrik Boström Reviewed-by: Harald Alvestrand Commit-Queue: Tommi Cr-Commit-Position: refs/heads/master@{#33775} --- pc/channel.h | 11 +++- pc/peer_connection.cc | 30 +++-------- pc/peer_connection.h | 6 +-- pc/peer_connection_internal.h | 5 +- pc/rtc_stats_collector.cc | 66 ++++++++++++------------ pc/rtc_stats_collector.h | 21 +++++--- pc/stats_collector.cc | 26 ++++++++-- pc/stats_collector.h | 6 ++- pc/test/fake_peer_connection_base.h | 4 +- pc/test/fake_peer_connection_for_stats.h | 14 +---- 10 files changed, 96 insertions(+), 93 deletions(-) diff --git a/pc/channel.h b/pc/channel.h index 5799edbb54..47ffc3e764 100644 --- a/pc/channel.h +++ b/pc/channel.h @@ -124,7 +124,13 @@ class BaseChannel : public ChannelInterface, rtc::Thread* network_thread() const { return network_thread_; } const std::string& content_name() const override { return content_name_; } // TODO(deadbeef): This is redundant; remove this. - const std::string& transport_name() const override { return transport_name_; } + const std::string& transport_name() const override { + RTC_DCHECK_RUN_ON(network_thread()); + if (rtp_transport_) + return rtp_transport_->transport_name(); + // TODO(tommi): Delete this variable. + return transport_name_; + } bool enabled() const override { return enabled_; } // This function returns true if using SRTP (DTLS-based keying or SDES). @@ -332,6 +338,9 @@ class BaseChannel : public ChannelInterface, // Won't be set when using raw packet transports. SDP-specific thing. // TODO(bugs.webrtc.org/12230): Written on network thread, read on // worker thread (at least). + // TODO(tommi): Remove this variable and instead use rtp_transport_ to + // return the transport name. This variable is currently required for + // "for_test" methods. std::string transport_name_; webrtc::RtpTransportInternal* rtp_transport_ diff --git a/pc/peer_connection.cc b/pc/peer_connection.cc index 1b09cdb007..6a3e5f1014 100644 --- a/pc/peer_connection.cc +++ b/pc/peer_connection.cc @@ -2094,6 +2094,7 @@ void PeerConnection::StopRtcEventLog_w() { cricket::ChannelInterface* PeerConnection::GetChannel( const std::string& content_name) { + RTC_DCHECK_RUN_ON(network_thread()); for (const auto& transceiver : rtp_manager()->transceivers()->List()) { cricket::ChannelInterface* channel = transceiver->internal()->channel(); if (channel && channel->content_name() == content_name) { @@ -2176,6 +2177,11 @@ absl::optional PeerConnection::sctp_transport_name() const { return absl::optional(); } +absl::optional PeerConnection::sctp_mid() const { + RTC_DCHECK_RUN_ON(signaling_thread()); + return sctp_mid_s_; +} + cricket::CandidateStatsList PeerConnection::GetPooledCandidateStats() const { RTC_DCHECK_RUN_ON(network_thread()); if (!network_thread_safety_->alive()) @@ -2185,30 +2191,6 @@ cricket::CandidateStatsList PeerConnection::GetPooledCandidateStats() const { return candidate_states_list; } -std::map PeerConnection::GetTransportNamesByMid() - const { - RTC_DCHECK_RUN_ON(network_thread()); - rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; - - if (!network_thread_safety_->alive()) - return {}; - - std::map transport_names_by_mid; - for (const auto& transceiver : rtp_manager()->transceivers()->List()) { - cricket::ChannelInterface* channel = transceiver->internal()->channel(); - if (channel) { - transport_names_by_mid[channel->content_name()] = - channel->transport_name(); - } - } - if (sctp_mid_n_) { - cricket::DtlsTransportInternal* dtls_transport = - transport_controller_->GetDtlsTransport(*sctp_mid_n_); - transport_names_by_mid[*sctp_mid_n_] = dtls_transport->transport_name(); - } - return transport_names_by_mid; -} - std::map PeerConnection::GetTransportStatsByNames( const std::set& transport_names) { diff --git a/pc/peer_connection.h b/pc/peer_connection.h index b44fb87bb9..5ba9ec35b8 100644 --- a/pc/peer_connection.h +++ b/pc/peer_connection.h @@ -293,9 +293,9 @@ class PeerConnection : public PeerConnectionInternal, std::vector GetDataChannelStats() const override; absl::optional sctp_transport_name() const override; + absl::optional sctp_mid() const override; cricket::CandidateStatsList GetPooledCandidateStats() const override; - std::map GetTransportNamesByMid() const override; std::map GetTransportStatsByNames( const std::set& transport_names) override; Call::Stats GetCallStats() override; @@ -342,10 +342,6 @@ class PeerConnection : public PeerConnectionInternal, RTC_DCHECK_RUN_ON(signaling_thread()); return &configuration_; } - absl::optional sctp_mid() { - RTC_DCHECK_RUN_ON(signaling_thread()); - return sctp_mid_s_; - } PeerConnectionMessageHandler* message_handler() { RTC_DCHECK_RUN_ON(signaling_thread()); return &message_handler_; diff --git a/pc/peer_connection_internal.h b/pc/peer_connection_internal.h index d800a58fd4..6f97612914 100644 --- a/pc/peer_connection_internal.h +++ b/pc/peer_connection_internal.h @@ -50,14 +50,13 @@ class PeerConnectionInternal : public PeerConnectionInterface { } virtual absl::optional sctp_transport_name() const = 0; + virtual absl::optional sctp_mid() const = 0; virtual cricket::CandidateStatsList GetPooledCandidateStats() const = 0; - // Returns a map from MID to transport name for all active media sections. - virtual std::map GetTransportNamesByMid() const = 0; - // Returns a map from transport name to transport stats for all given // transport names. + // Must be called on the network thread. virtual std::map GetTransportStatsByNames(const std::set& transport_names) = 0; diff --git a/pc/rtc_stats_collector.cc b/pc/rtc_stats_collector.cc index 888d6389f0..93aa6af7a8 100644 --- a/pc/rtc_stats_collector.cc +++ b/pc/rtc_stats_collector.cc @@ -1197,19 +1197,19 @@ void RTCStatsCollector::GetStatsReportInternal( // Prepare |transceiver_stats_infos_| and |call_stats_| for use in // |ProducePartialResultsOnNetworkThread| and // |ProducePartialResultsOnSignalingThread|. - PrepareTransceiverStatsInfosAndCallStats_s_w(); - // Prepare |transport_names_| for use in - // |ProducePartialResultsOnNetworkThread|. - transport_names_ = PrepareTransportNames_s(); - + PrepareTransceiverStatsInfosAndCallStats_s_w_n(); // Don't touch |network_report_| on the signaling thread until // ProducePartialResultsOnNetworkThread() has signaled the // |network_report_event_|. network_report_event_.Reset(); rtc::scoped_refptr collector(this); - network_thread_->PostTask(RTC_FROM_HERE, [collector, timestamp_us] { - collector->ProducePartialResultsOnNetworkThread(timestamp_us); - }); + network_thread_->PostTask( + RTC_FROM_HERE, + [collector, sctp_transport_name = pc_->sctp_transport_name(), + timestamp_us]() mutable { + collector->ProducePartialResultsOnNetworkThread( + timestamp_us, std::move(sctp_transport_name)); + }); ProducePartialResultsOnSignalingThread(timestamp_us); } } @@ -1258,7 +1258,8 @@ void RTCStatsCollector::ProducePartialResultsOnSignalingThreadImpl( } void RTCStatsCollector::ProducePartialResultsOnNetworkThread( - int64_t timestamp_us) { + int64_t timestamp_us, + absl::optional sctp_transport_name) { RTC_DCHECK_RUN_ON(network_thread_); rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; @@ -1266,8 +1267,18 @@ void RTCStatsCollector::ProducePartialResultsOnNetworkThread( // |network_report_event_| is reset before this method is invoked. network_report_ = RTCStatsReport::Create(timestamp_us); + std::set transport_names; + if (sctp_transport_name) { + transport_names.emplace(std::move(*sctp_transport_name)); + } + + for (const auto& info : transceiver_stats_infos_) { + if (info.transport_name) + transport_names.insert(*info.transport_name); + } + std::map transport_stats_by_name = - pc_->GetTransportStatsByNames(transport_names_); + pc_->GetTransportStatsByNames(transport_names); std::map transport_cert_stats = PrepareTransportCertificateStats_n(transport_stats_by_name); @@ -2027,7 +2038,7 @@ RTCStatsCollector::PrepareTransportCertificateStats_n( return transport_cert_stats; } -void RTCStatsCollector::PrepareTransceiverStatsInfosAndCallStats_s_w() { +void RTCStatsCollector::PrepareTransceiverStatsInfosAndCallStats_s_w_n() { RTC_DCHECK_RUN_ON(signaling_thread_); transceiver_stats_infos_.clear(); @@ -2040,20 +2051,26 @@ void RTCStatsCollector::PrepareTransceiverStatsInfosAndCallStats_s_w() { std::unique_ptr> video_stats; - { + auto transceivers = pc_->GetTransceiversInternal(); + + // TODO(tommi): See if we can avoid synchronously blocking the signaling + // thread while we do this (or avoid the Invoke at all). + network_thread_->Invoke(RTC_FROM_HERE, [this, &transceivers, + &voice_stats, &video_stats] { rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; - for (const auto& transceiver : pc_->GetTransceiversInternal()) { + for (const auto& transceiver_proxy : transceivers) { + RtpTransceiver* transceiver = transceiver_proxy->internal(); cricket::MediaType media_type = transceiver->media_type(); // Prepare stats entry. The TrackMediaInfoMap will be filled in after the // stats have been fetched on the worker thread. transceiver_stats_infos_.emplace_back(); RtpTransceiverStatsInfo& stats = transceiver_stats_infos_.back(); - stats.transceiver = transceiver->internal(); + stats.transceiver = transceiver; stats.media_type = media_type; - cricket::ChannelInterface* channel = transceiver->internal()->channel(); + cricket::ChannelInterface* channel = transceiver->channel(); if (!channel) { // The remaining fields require a BaseChannel. continue; @@ -2078,7 +2095,7 @@ void RTCStatsCollector::PrepareTransceiverStatsInfosAndCallStats_s_w() { RTC_NOTREACHED(); } } - } + }); // We jump to the worker thread and call GetStats() on each media channel as // well as GetCallStats(). At the same time we construct the @@ -2137,23 +2154,6 @@ void RTCStatsCollector::PrepareTransceiverStatsInfosAndCallStats_s_w() { }); } -std::set RTCStatsCollector::PrepareTransportNames_s() const { - RTC_DCHECK_RUN_ON(signaling_thread_); - rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; - - std::set transport_names; - for (const auto& transceiver : pc_->GetTransceiversInternal()) { - if (transceiver->internal()->channel()) { - transport_names.insert( - transceiver->internal()->channel()->transport_name()); - } - } - if (pc_->sctp_transport_name()) { - transport_names.insert(*pc_->sctp_transport_name()); - } - return transport_names; -} - void RTCStatsCollector::OnSctpDataChannelCreated(SctpDataChannel* channel) { channel->SignalOpened.connect(this, &RTCStatsCollector::OnDataChannelOpened); channel->SignalClosed.connect(this, &RTCStatsCollector::OnDataChannelClosed); diff --git a/pc/rtc_stats_collector.h b/pc/rtc_stats_collector.h index 624ca00f68..b5b8c8c900 100644 --- a/pc/rtc_stats_collector.h +++ b/pc/rtc_stats_collector.h @@ -227,12 +227,13 @@ class RTCStatsCollector : public virtual rtc::RefCountInterface, const std::map& transport_stats_by_name) const; // The results are stored in |transceiver_stats_infos_| and |call_stats_|. - void PrepareTransceiverStatsInfosAndCallStats_s_w(); - std::set PrepareTransportNames_s() const; + void PrepareTransceiverStatsInfosAndCallStats_s_w_n(); // Stats gathering on a particular thread. void ProducePartialResultsOnSignalingThread(int64_t timestamp_us); - void ProducePartialResultsOnNetworkThread(int64_t timestamp_us); + void ProducePartialResultsOnNetworkThread( + int64_t timestamp_us, + absl::optional sctp_transport_name); // Merges |network_report_| into |partial_report_| and completes the request. // This is a NO-OP if |network_report_| is null. void MergeNetworkReport_s(); @@ -266,12 +267,16 @@ class RTCStatsCollector : public virtual rtc::RefCountInterface, // has updated the value of |network_report_|. rtc::Event network_report_event_; - // Set in |GetStatsReport|, read in |ProducePartialResultsOnNetworkThread| and - // |ProducePartialResultsOnSignalingThread|, reset after work is complete. Not - // passed as arguments to avoid copies. This is thread safe - when we - // set/reset we know there are no pending stats requests in progress. + // Cleared and set in `PrepareTransceiverStatsInfosAndCallStats_s_w_n`, + // starting out on the signaling thread, then network. Later read on the + // network and signaling threads as part of collecting stats and finally + // reset when the work is done. Initially this variable was added and not + // passed around as an arguments to avoid copies. This is thread safe due to + // how operations are sequenced and we don't start the stats collection + // sequence if one is in progress. As a future improvement though, we could + // now get rid of the variable and keep the data scoped within a stats + // collection sequence. std::vector transceiver_stats_infos_; - std::set transport_names_; Call::Stats call_stats_; diff --git a/pc/stats_collector.cc b/pc/stats_collector.cc index 917d055eff..8955729192 100644 --- a/pc/stats_collector.cc +++ b/pc/stats_collector.cc @@ -852,20 +852,40 @@ std::map StatsCollector::ExtractSessionInfo() { RTC_DCHECK_RUN_ON(pc_->signaling_thread()); SessionStats stats; + auto transceivers = pc_->GetTransceiversInternal(); pc_->network_thread()->Invoke( - RTC_FROM_HERE, [this, &stats] { stats = ExtractSessionInfo_n(); }); + RTC_FROM_HERE, [&, sctp_transport_name = pc_->sctp_transport_name(), + sctp_mid = pc_->sctp_mid()]() mutable { + stats = ExtractSessionInfo_n( + transceivers, std::move(sctp_transport_name), std::move(sctp_mid)); + }); ExtractSessionInfo_s(stats); return std::move(stats.transport_names_by_mid); } -StatsCollector::SessionStats StatsCollector::ExtractSessionInfo_n() { +StatsCollector::SessionStats StatsCollector::ExtractSessionInfo_n( + const std::vector>>& transceivers, + absl::optional sctp_transport_name, + absl::optional sctp_mid) { RTC_DCHECK_RUN_ON(pc_->network_thread()); rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; SessionStats stats; stats.candidate_stats = pc_->GetPooledCandidateStats(); - stats.transport_names_by_mid = pc_->GetTransportNamesByMid(); + for (auto& transceiver : transceivers) { + cricket::ChannelInterface* channel = transceiver->internal()->channel(); + if (channel) { + stats.transport_names_by_mid[channel->content_name()] = + channel->transport_name(); + } + } + + if (sctp_transport_name) { + RTC_DCHECK(sctp_mid); + stats.transport_names_by_mid[*sctp_mid] = *sctp_transport_name; + } std::set transport_names; for (const auto& entry : stats.transport_names_by_mid) { diff --git a/pc/stats_collector.h b/pc/stats_collector.h index eaefc438f2..2fd5d9d8f8 100644 --- a/pc/stats_collector.h +++ b/pc/stats_collector.h @@ -180,7 +180,11 @@ class StatsCollector : public StatsCollectorInterface { // Helper method to update the timestamp of track records. void UpdateTrackReports(); - SessionStats ExtractSessionInfo_n(); + SessionStats ExtractSessionInfo_n( + const std::vector>>& transceivers, + absl::optional sctp_transport_name, + absl::optional sctp_mid); void ExtractSessionInfo_s(SessionStats& session_stats); // A collection for all of our stats reports. diff --git a/pc/test/fake_peer_connection_base.h b/pc/test/fake_peer_connection_base.h index 423d86abc9..1acf86fdac 100644 --- a/pc/test/fake_peer_connection_base.h +++ b/pc/test/fake_peer_connection_base.h @@ -256,8 +256,8 @@ class FakePeerConnectionBase : public PeerConnectionInternal { return absl::nullopt; } - std::map GetTransportNamesByMid() const override { - return {}; + absl::optional sctp_mid() const override { + return absl::nullopt; } std::map GetTransportStatsByNames( diff --git a/pc/test/fake_peer_connection_for_stats.h b/pc/test/fake_peer_connection_for_stats.h index f51a69a04c..3f3e0a9ee0 100644 --- a/pc/test/fake_peer_connection_for_stats.h +++ b/pc/test/fake_peer_connection_for_stats.h @@ -328,21 +328,9 @@ class FakePeerConnectionForStats : public FakePeerConnectionBase { return {}; } - std::map GetTransportNamesByMid() const override { - std::map transport_names_by_mid; - if (voice_channel_) { - transport_names_by_mid[voice_channel_->content_name()] = - voice_channel_->transport_name(); - } - if (video_channel_) { - transport_names_by_mid[video_channel_->content_name()] = - video_channel_->transport_name(); - } - return transport_names_by_mid; - } - std::map GetTransportStatsByNames( const std::set& transport_names) override { + RTC_DCHECK_RUN_ON(network_thread_); std::map transport_stats_by_name; for (const std::string& transport_name : transport_names) { transport_stats_by_name[transport_name] =