diff --git a/modules/video_coding/BUILD.gn b/modules/video_coding/BUILD.gn index 3ac8751776..830701144c 100644 --- a/modules/video_coding/BUILD.gn +++ b/modules/video_coding/BUILD.gn @@ -202,6 +202,18 @@ rtc_library("timing") { absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } +rtc_library("rtt_filter") { + sources = [ + "rtt_filter.cc", + "rtt_filter.h", + ] + deps = [ "../../api/units:time_delta" ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/container:inlined_vector", + ] +} + rtc_library("video_coding") { visibility = [ "*" ] sources = [ @@ -242,8 +254,6 @@ rtc_library("video_coding") { "rtp_vp8_ref_finder.h", "rtp_vp9_ref_finder.cc", "rtp_vp9_ref_finder.h", - "rtt_filter.cc", - "rtt_filter.h", "timestamp_map.cc", "timestamp_map.h", "unique_timestamp_counter.cc", @@ -259,6 +269,7 @@ rtc_library("video_coding") { ":frame_buffer", ":frame_helpers", ":packet_buffer", + ":rtt_filter", ":timing", ":video_codec_interface", ":video_coding_utility", @@ -1084,6 +1095,7 @@ if (rtc_include_tests) { "rtp_frame_reference_finder_unittest.cc", "rtp_vp8_ref_finder_unittest.cc", "rtp_vp9_ref_finder_unittest.cc", + "rtt_filter_unittest.cc", "session_info_unittest.cc", "test/stream_generator.cc", "test/stream_generator.h", @@ -1119,6 +1131,7 @@ if (rtc_include_tests) { ":h264_packet_buffer", ":nack_requester", ":packet_buffer", + ":rtt_filter", ":simulcast_test_fixture_impl", ":timing", ":video_codec_interface", diff --git a/modules/video_coding/jitter_estimator.cc b/modules/video_coding/jitter_estimator.cc index 38ce35b61e..cf57232d0f 100644 --- a/modules/video_coding/jitter_estimator.cc +++ b/modules/video_coding/jitter_estimator.cc @@ -346,7 +346,7 @@ void VCMJitterEstimator::PostProcessEstimate() { } void VCMJitterEstimator::UpdateRtt(int64_t rttMs) { - _rttFilter.Update(rttMs); + _rttFilter.Update(TimeDelta::Millis(rttMs)); } // Returns the current filtered estimate if available, @@ -364,10 +364,10 @@ int VCMJitterEstimator::GetJitterEstimate( jitterMS = _filterJitterEstimate; if (_nackCount >= _nackLimit) { if (rttMultAddCapMs.has_value()) { - jitterMS += - std::min(_rttFilter.RttMs() * rttMultiplier, rttMultAddCapMs.value()); + jitterMS += std::min(_rttFilter.Rtt().ms() * rttMultiplier, + rttMultAddCapMs.value()); } else { - jitterMS += _rttFilter.RttMs() * rttMultiplier; + jitterMS += _rttFilter.Rtt().ms() * rttMultiplier; } } diff --git a/modules/video_coding/rtt_filter.cc b/modules/video_coding/rtt_filter.cc index 773ff6867e..eaf3b2b301 100644 --- a/modules/video_coding/rtt_filter.cc +++ b/modules/video_coding/rtt_filter.cc @@ -14,137 +14,148 @@ #include #include -#include "modules/video_coding/internal_defines.h" +#include + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "api/units/time_delta.h" namespace webrtc { +namespace { + +constexpr TimeDelta kMaxRtt = TimeDelta::Seconds(3); +constexpr uint32_t kFilterFactorMax = 35; +constexpr double kJumpStddev = 2.5; +constexpr double kDriftStdDev = 3.5; + +} // namespace + VCMRttFilter::VCMRttFilter() - : _filtFactMax(35), - _jumpStdDevs(2.5), - _driftStdDevs(3.5), - _detectThreshold(kMaxDriftJumpCount) { + : avg_rtt_(TimeDelta::Zero()), + var_rtt_(0), + max_rtt_(TimeDelta::Zero()), + jump_buf_(kMaxDriftJumpCount, TimeDelta::Zero()), + drift_buf_(kMaxDriftJumpCount, TimeDelta::Zero()) { Reset(); } void VCMRttFilter::Reset() { - _gotNonZeroUpdate = false; - _avgRtt = 0; - _varRtt = 0; - _maxRtt = 0; - _filtFactCount = 1; - _jumpCount = 0; - _driftCount = 0; - memset(_jumpBuf, 0, sizeof(_jumpBuf)); - memset(_driftBuf, 0, sizeof(_driftBuf)); + got_non_zero_update_ = false; + avg_rtt_ = TimeDelta::Zero(); + var_rtt_ = 0; + max_rtt_ = TimeDelta::Zero(); + filt_fact_count_ = 1; + absl::c_fill(jump_buf_, TimeDelta::Zero()); + absl::c_fill(drift_buf_, TimeDelta::Zero()); } -void VCMRttFilter::Update(int64_t rttMs) { - if (!_gotNonZeroUpdate) { - if (rttMs == 0) { +void VCMRttFilter::Update(TimeDelta rtt) { + if (!got_non_zero_update_) { + if (rtt.IsZero()) { return; } - _gotNonZeroUpdate = true; + got_non_zero_update_ = true; } // Sanity check - if (rttMs > 3000) { - rttMs = 3000; + if (rtt > kMaxRtt) { + rtt = kMaxRtt; } - double filtFactor = 0; - if (_filtFactCount > 1) { - filtFactor = static_cast(_filtFactCount - 1) / _filtFactCount; + double filt_factor = 0; + if (filt_fact_count_ > 1) { + filt_factor = static_cast(filt_fact_count_ - 1) / filt_fact_count_; } - _filtFactCount++; - if (_filtFactCount > _filtFactMax) { - // This prevents filtFactor from going above - // (_filtFactMax - 1) / _filtFactMax, - // e.g., _filtFactMax = 50 => filtFactor = 49/50 = 0.98 - _filtFactCount = _filtFactMax; + filt_fact_count_++; + if (filt_fact_count_ > kFilterFactorMax) { + // This prevents filt_factor from going above + // (_filt_fact_max - 1) / filt_fact_max_, + // e.g., filt_fact_max_ = 50 => filt_factor = 49/50 = 0.98 + filt_fact_count_ = kFilterFactorMax; } - double oldAvg = _avgRtt; - double oldVar = _varRtt; - _avgRtt = filtFactor * _avgRtt + (1 - filtFactor) * rttMs; - _varRtt = filtFactor * _varRtt + - (1 - filtFactor) * (rttMs - _avgRtt) * (rttMs - _avgRtt); - _maxRtt = VCM_MAX(rttMs, _maxRtt); - if (!JumpDetection(rttMs) || !DriftDetection(rttMs)) { + TimeDelta old_avg = avg_rtt_; + int64_t old_var = var_rtt_; + avg_rtt_ = filt_factor * avg_rtt_ + (1 - filt_factor) * rtt; + int64_t delta_ms = (rtt - avg_rtt_).ms(); + var_rtt_ = filt_factor * var_rtt_ + (1 - filt_factor) * (delta_ms * delta_ms); + max_rtt_ = std::max(rtt, max_rtt_); + if (!JumpDetection(rtt) || !DriftDetection(rtt)) { // In some cases we don't want to update the statistics - _avgRtt = oldAvg; - _varRtt = oldVar; + avg_rtt_ = old_avg; + var_rtt_ = old_var; } } -bool VCMRttFilter::JumpDetection(int64_t rttMs) { - double diffFromAvg = _avgRtt - rttMs; - if (fabs(diffFromAvg) > _jumpStdDevs * sqrt(_varRtt)) { - int diffSign = (diffFromAvg >= 0) ? 1 : -1; - int jumpCountSign = (_jumpCount >= 0) ? 1 : -1; - if (diffSign != jumpCountSign) { +bool VCMRttFilter::JumpDetection(TimeDelta rtt) { + TimeDelta diff_from_avg = avg_rtt_ - rtt; + // Unit of var_rtt_ is ms^2. + TimeDelta jump_threshold = TimeDelta::Millis(kJumpStddev * sqrt(var_rtt_)); + if (diff_from_avg.Abs() > jump_threshold) { + bool positive_diff = diff_from_avg >= TimeDelta::Zero(); + if (!jump_buf_.empty() && positive_diff != last_jump_positive_) { // Since the signs differ the samples currently // in the buffer is useless as they represent a // jump in a different direction. - _jumpCount = 0; + jump_buf_.clear(); } - if (abs(_jumpCount) < kMaxDriftJumpCount) { - // Update the buffer used for the short time - // statistics. + if (jump_buf_.size() < kMaxDriftJumpCount) { + // Update the buffer used for the short time statistics. // The sign of the diff is used for updating the counter since // we want to use the same buffer for keeping track of when // the RTT jumps down and up. - _jumpBuf[abs(_jumpCount)] = rttMs; - _jumpCount += diffSign; + jump_buf_.push_back(rtt); + last_jump_positive_ = positive_diff; } - if (abs(_jumpCount) >= _detectThreshold) { + if (jump_buf_.size() >= kMaxDriftJumpCount) { // Detected an RTT jump - ShortRttFilter(_jumpBuf, abs(_jumpCount)); - _filtFactCount = _detectThreshold + 1; - _jumpCount = 0; + ShortRttFilter(jump_buf_); + filt_fact_count_ = kMaxDriftJumpCount + 1; + jump_buf_.clear(); } else { return false; } } else { - _jumpCount = 0; + jump_buf_.clear(); } return true; } -bool VCMRttFilter::DriftDetection(int64_t rttMs) { - if (_maxRtt - _avgRtt > _driftStdDevs * sqrt(_varRtt)) { - if (_driftCount < kMaxDriftJumpCount) { - // Update the buffer used for the short time - // statistics. - _driftBuf[_driftCount] = rttMs; - _driftCount++; +bool VCMRttFilter::DriftDetection(TimeDelta rtt) { + // Unit of sqrt of var_rtt_ is ms. + TimeDelta drift_threshold = TimeDelta::Millis(kDriftStdDev * sqrt(var_rtt_)); + if (max_rtt_ - avg_rtt_ > drift_threshold) { + if (drift_buf_.size() < kMaxDriftJumpCount) { + // Update the buffer used for the short time statistics. + drift_buf_.push_back(rtt); } - if (_driftCount >= _detectThreshold) { + if (drift_buf_.size() >= kMaxDriftJumpCount) { // Detected an RTT drift - ShortRttFilter(_driftBuf, _driftCount); - _filtFactCount = _detectThreshold + 1; - _driftCount = 0; + ShortRttFilter(drift_buf_); + filt_fact_count_ = kMaxDriftJumpCount + 1; + drift_buf_.clear(); } } else { - _driftCount = 0; + drift_buf_.clear(); } return true; } -void VCMRttFilter::ShortRttFilter(int64_t* buf, uint32_t length) { - if (length == 0) { - return; - } - _maxRtt = 0; - _avgRtt = 0; - for (uint32_t i = 0; i < length; i++) { - if (buf[i] > _maxRtt) { - _maxRtt = buf[i]; +void VCMRttFilter::ShortRttFilter(const BufferList& buf) { + RTC_DCHECK_EQ(buf.size(), kMaxDriftJumpCount); + max_rtt_ = TimeDelta::Zero(); + avg_rtt_ = TimeDelta::Zero(); + for (const TimeDelta& rtt : buf) { + if (rtt > max_rtt_) { + max_rtt_ = rtt; } - _avgRtt += buf[i]; + avg_rtt_ += rtt; } - _avgRtt = _avgRtt / static_cast(length); + avg_rtt_ = avg_rtt_ / static_cast(buf.size()); } -int64_t VCMRttFilter::RttMs() const { - return static_cast(_maxRtt + 0.5); +TimeDelta VCMRttFilter::Rtt() const { + return max_rtt_; } + } // namespace webrtc diff --git a/modules/video_coding/rtt_filter.h b/modules/video_coding/rtt_filter.h index bc4f56d2b9..a611aafcb8 100644 --- a/modules/video_coding/rtt_filter.h +++ b/modules/video_coding/rtt_filter.h @@ -13,6 +13,9 @@ #include +#include "absl/container/inlined_vector.h" +#include "api/units/time_delta.h" + namespace webrtc { class VCMRttFilter { @@ -24,41 +27,41 @@ class VCMRttFilter { // Resets the filter. void Reset(); // Updates the filter with a new sample. - void Update(int64_t rttMs); - // A getter function for the current RTT level in ms. - int64_t RttMs() const; + void Update(TimeDelta rtt); + // A getter function for the current RTT level. + TimeDelta Rtt() const; private: // The size of the drift and jump memory buffers // and thus also the detection threshold for these // detectors in number of samples. - enum { kMaxDriftJumpCount = 5 }; + static constexpr int kMaxDriftJumpCount = 5; + using BufferList = absl::InlinedVector; + // Detects RTT jumps by comparing the difference between // samples and average to the standard deviation. // Returns true if the long time statistics should be updated // and false otherwise - bool JumpDetection(int64_t rttMs); + bool JumpDetection(TimeDelta rtt); + // Detects RTT drifts by comparing the difference between // max and average to the standard deviation. // Returns true if the long time statistics should be updated // and false otherwise - bool DriftDetection(int64_t rttMs); - // Computes the short time average and maximum of the vector buf. - void ShortRttFilter(int64_t* buf, uint32_t length); + bool DriftDetection(TimeDelta rtt); - bool _gotNonZeroUpdate; - double _avgRtt; - double _varRtt; - int64_t _maxRtt; - uint32_t _filtFactCount; - const uint32_t _filtFactMax; - const double _jumpStdDevs; - const double _driftStdDevs; - int32_t _jumpCount; - int32_t _driftCount; - const int32_t _detectThreshold; - int64_t _jumpBuf[kMaxDriftJumpCount]; - int64_t _driftBuf[kMaxDriftJumpCount]; + // Computes the short time average and maximum of the vector buf. + void ShortRttFilter(const BufferList& buf); + + bool got_non_zero_update_; + TimeDelta avg_rtt_; + // Variance units are TimeDelta^2. Store as ms^2. + int64_t var_rtt_; + TimeDelta max_rtt_; + uint32_t filt_fact_count_; + bool last_jump_positive_ = false; + BufferList jump_buf_; + BufferList drift_buf_; }; } // namespace webrtc diff --git a/modules/video_coding/rtt_filter_unittest.cc b/modules/video_coding/rtt_filter_unittest.cc new file mode 100644 index 0000000000..15d7d66b83 --- /dev/null +++ b/modules/video_coding/rtt_filter_unittest.cc @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/video_coding/rtt_filter.h" + +#include "api/units/time_delta.h" +#include "test/gmock.h" +#include "test/gtest.h" + +namespace webrtc { + +TEST(RttFilterTest, RttIsCapped) { + VCMRttFilter rtt_filter; + rtt_filter.Update(TimeDelta::Seconds(500)); + + EXPECT_EQ(rtt_filter.Rtt(), TimeDelta::Seconds(3)); +} + +// If the difference between samples is more than away 2.5 stddev from the mean +// then this is considered a jump. After more than 5 data points at the new +// level, the RTT is reset to the new level. +TEST(RttFilterTest, PositiveJumpDetection) { + VCMRttFilter rtt_filter; + + rtt_filter.Update(TimeDelta::Millis(200)); + rtt_filter.Update(TimeDelta::Millis(200)); + rtt_filter.Update(TimeDelta::Millis(200)); + + // Trigger 5 jumps. + rtt_filter.Update(TimeDelta::Millis(1400)); + rtt_filter.Update(TimeDelta::Millis(1500)); + rtt_filter.Update(TimeDelta::Millis(1600)); + rtt_filter.Update(TimeDelta::Millis(1600)); + + EXPECT_EQ(rtt_filter.Rtt(), TimeDelta::Millis(1600)); + + rtt_filter.Update(TimeDelta::Millis(1600)); + EXPECT_EQ(rtt_filter.Rtt(), TimeDelta::Millis(1600)); +} + +TEST(RttFilterTest, NegativeJumpDetection) { + VCMRttFilter rtt_filter; + + for (int i = 0; i < 10; ++i) + rtt_filter.Update(TimeDelta::Millis(1500)); + + // Trigger 5 negative data points that jump rtt down. + rtt_filter.Update(TimeDelta::Millis(200)); + rtt_filter.Update(TimeDelta::Millis(200)); + rtt_filter.Update(TimeDelta::Millis(200)); + rtt_filter.Update(TimeDelta::Millis(200)); + // Before 5 data points at the new level, max RTT is still 1500. + EXPECT_EQ(rtt_filter.Rtt(), TimeDelta::Millis(1500)); + + rtt_filter.Update(TimeDelta::Millis(300)); + EXPECT_EQ(rtt_filter.Rtt(), TimeDelta::Millis(300)); +} + +TEST(RttFilterTest, JumpsResetByDirectionShift) { + VCMRttFilter rtt_filter; + for (int i = 0; i < 10; ++i) + rtt_filter.Update(TimeDelta::Millis(1500)); + + // Trigger 4 negative jumps, then a positive one. This resets the jump + // detection. + rtt_filter.Update(TimeDelta::Millis(200)); + rtt_filter.Update(TimeDelta::Millis(200)); + rtt_filter.Update(TimeDelta::Millis(200)); + rtt_filter.Update(TimeDelta::Millis(200)); + rtt_filter.Update(TimeDelta::Millis(2000)); + EXPECT_EQ(rtt_filter.Rtt(), TimeDelta::Millis(2000)); + + rtt_filter.Update(TimeDelta::Millis(300)); + EXPECT_EQ(rtt_filter.Rtt(), TimeDelta::Millis(2000)); +} + +// If the difference between the max and average is more than 3.5 stddevs away +// then a drift is detected, and a short filter is applied to find a new max +// rtt. +TEST(RttFilterTest, DriftDetection) { + VCMRttFilter rtt_filter; + + // Descend RTT by 30ms and settle at 700ms RTT. A drift is detected after rtt + // of 700ms is reported around 50 times for these targets. + constexpr TimeDelta kStartRtt = TimeDelta::Millis(1000); + constexpr TimeDelta kDriftTarget = TimeDelta::Millis(700); + constexpr TimeDelta kDelta = TimeDelta::Millis(30); + for (TimeDelta rtt = kStartRtt; rtt >= kDriftTarget; rtt -= kDelta) + rtt_filter.Update(rtt); + + EXPECT_EQ(rtt_filter.Rtt(), kStartRtt); + + for (int i = 0; i < 50; ++i) + rtt_filter.Update(kDriftTarget); + EXPECT_EQ(rtt_filter.Rtt(), kDriftTarget); +} + +} // namespace webrtc