diff --git a/rtc_base/rate_statistics.cc b/rtc_base/rate_statistics.cc index 89f7e54a68..c4c2e78581 100644 --- a/rtc_base/rate_statistics.cc +++ b/rtc_base/rate_statistics.cc @@ -15,6 +15,8 @@ #include #include "rtc_base/checks.h" +#include "rtc_base/logging.h" +#include "rtc_base/numerics/safe_conversions.h" namespace webrtc { @@ -30,6 +32,7 @@ RateStatistics::RateStatistics(int64_t window_size_ms, float scale) RateStatistics::RateStatistics(const RateStatistics& other) : accumulated_count_(other.accumulated_count_), + overflow_(other.overflow_), num_samples_(other.num_samples_), oldest_time_(other.oldest_time_), oldest_index_(other.oldest_index_), @@ -47,6 +50,7 @@ RateStatistics::~RateStatistics() {} void RateStatistics::Reset() { accumulated_count_ = 0; + overflow_ = false; num_samples_ = 0; oldest_time_ = -max_window_size_ms_; oldest_index_ = 0; @@ -55,7 +59,8 @@ void RateStatistics::Reset() { buckets_[i] = Bucket(); } -void RateStatistics::Update(size_t count, int64_t now_ms) { +void RateStatistics::Update(int64_t count, int64_t now_ms) { + RTC_DCHECK_LE(0, count); if (now_ms < oldest_time_) { // Too old data is ignored. return; @@ -67,38 +72,45 @@ void RateStatistics::Update(size_t count, int64_t now_ms) { if (!IsInitialized()) oldest_time_ = now_ms; - uint32_t now_offset = static_cast(now_ms - oldest_time_); + uint32_t now_offset = rtc::dchecked_cast(now_ms - oldest_time_); RTC_DCHECK_LT(now_offset, max_window_size_ms_); uint32_t index = oldest_index_ + now_offset; if (index >= max_window_size_ms_) index -= max_window_size_ms_; buckets_[index].sum += count; ++buckets_[index].samples; - accumulated_count_ += count; + if (std::numeric_limits::max() - accumulated_count_ > count) { + accumulated_count_ += count; + } else { + overflow_ = true; + } ++num_samples_; } -absl::optional RateStatistics::Rate(int64_t now_ms) const { +absl::optional RateStatistics::Rate(int64_t now_ms) const { // Yeah, this const_cast ain't pretty, but the alternative is to declare most // of the members as mutable... const_cast(this)->EraseOld(now_ms); // If window is a single bucket or there is only one sample in a data set that - // has not grown to the full window size, treat this as rate unavailable. - int64_t active_window_size = now_ms - oldest_time_ + 1; + // has not grown to the full window size, or if the accumulator has + // overflowed, treat this as rate unavailable. + int active_window_size = now_ms - oldest_time_ + 1; if (num_samples_ == 0 || active_window_size <= 1 || - (num_samples_ <= 1 && active_window_size < current_window_size_ms_)) { + (num_samples_ <= 1 && + rtc::SafeLt(active_window_size, current_window_size_ms_)) || + overflow_) { return absl::nullopt; } - float scale = scale_ / active_window_size; + float scale = static_cast(scale_) / active_window_size; float result = accumulated_count_ * scale + 0.5f; // Better return unavailable rate than garbage value (undefined behavior). - if (result > std::numeric_limits::max()) { + if (result > static_cast(std::numeric_limits::max())) { return absl::nullopt; } - return static_cast(result); + return rtc::dchecked_cast(result); } void RateStatistics::EraseOld(int64_t now_ms) { @@ -123,6 +135,8 @@ void RateStatistics::EraseOld(int64_t now_ms) { if (++oldest_index_ >= max_window_size_ms_) oldest_index_ = 0; ++oldest_time_; + // This does not clear overflow_ even when counter is empty. + // TODO(https://bugs.webrtc.org/11247): Consider if overflow_ can be reset. } oldest_time_ = new_oldest_time; } @@ -130,7 +144,6 @@ void RateStatistics::EraseOld(int64_t now_ms) { bool RateStatistics::SetWindowSize(int64_t window_size_ms, int64_t now_ms) { if (window_size_ms <= 0 || window_size_ms > max_window_size_ms_) return false; - current_window_size_ms_ = window_size_ms; EraseOld(now_ms); return true; diff --git a/rtc_base/rate_statistics.h b/rtc_base/rate_statistics.h index 65b5fa10d9..11c8cee7af 100644 --- a/rtc_base/rate_statistics.h +++ b/rtc_base/rate_statistics.h @@ -21,6 +21,13 @@ namespace webrtc { +// Class to estimate rates based on counts in a sequence of 1-millisecond +// intervals. + +// This class uses int64 for all its numbers because some rates can be very +// high; for instance, a 20 Mbit/sec video stream can wrap a 32-bit byte +// counter in 14 minutes. + class RTC_EXPORT RateStatistics { public: static constexpr float kBpsScale = 8000.0f; @@ -42,7 +49,7 @@ class RTC_EXPORT RateStatistics { void Reset(); // Update rate with a new data point, moving averaging window as needed. - void Update(size_t count, int64_t now_ms); + void Update(int64_t count, int64_t now_ms); // Note that despite this being a const method, it still updates the internal // state (moves averaging window), but it doesn't make any alterations that @@ -50,7 +57,7 @@ class RTC_EXPORT RateStatistics { // from a monotonic clock. Ie, it doesn't matter if this call moves the // window, since any subsequent call to Update or Rate would still have moved // the window as much or more. - absl::optional Rate(int64_t now_ms) const; + absl::optional Rate(int64_t now_ms) const; // Update the size of the averaging window. The maximum allowed value for // window_size_ms is max_window_size_ms as supplied in the constructor. @@ -63,22 +70,26 @@ class RTC_EXPORT RateStatistics { // Counters are kept in buckets (circular buffer), with one bucket // per millisecond. struct Bucket { - size_t sum; // Sum of all samples in this bucket. - size_t samples; // Number of samples in this bucket. + int64_t sum; // Sum of all samples in this bucket. + int samples; // Number of samples in this bucket. }; std::unique_ptr buckets_; // Total count recorded in buckets. - size_t accumulated_count_; + int64_t accumulated_count_; + + // True if accumulated_count_ has ever grown too large to be + // contained in its integer type. + bool overflow_ = false; // The total number of samples in the buckets. - size_t num_samples_; + int num_samples_; // Oldest time recorded in buckets. int64_t oldest_time_; // Bucket index of oldest counter recorded in buckets. - uint32_t oldest_index_; + int64_t oldest_index_; // To convert counts/ms to desired units const float scale_; diff --git a/rtc_base/rate_statistics_unittest.cc b/rtc_base/rate_statistics_unittest.cc index 9dd82327ba..735677082b 100644 --- a/rtc_base/rate_statistics_unittest.cc +++ b/rtc_base/rate_statistics_unittest.cc @@ -278,4 +278,33 @@ TEST_F(RateStatisticsTest, HandlesQuietPeriods) { EXPECT_TRUE(static_cast(bitrate)); EXPECT_EQ(0u, *bitrate); } + +TEST_F(RateStatisticsTest, HandlesBigNumbers) { + int64_t large_number = 0x100000000u; + int64_t now_ms = 0; + stats_.Update(large_number, now_ms++); + stats_.Update(large_number, now_ms); + EXPECT_TRUE(stats_.Rate(now_ms)); + EXPECT_EQ(large_number * RateStatistics::kBpsScale, *stats_.Rate(now_ms)); +} + +TEST_F(RateStatisticsTest, HandlesTooLargeNumbers) { + int64_t very_large_number = std::numeric_limits::max(); + int64_t now_ms = 0; + stats_.Update(very_large_number, now_ms++); + stats_.Update(very_large_number, now_ms); + // This should overflow the internal accumulator. + EXPECT_FALSE(stats_.Rate(now_ms)); +} + +TEST_F(RateStatisticsTest, HandlesSomewhatLargeNumbers) { + int64_t very_large_number = std::numeric_limits::max(); + int64_t now_ms = 0; + stats_.Update(very_large_number / 4, now_ms++); + stats_.Update(very_large_number / 4, now_ms); + // This should generate a rate of more than int64_t max, but still + // accumulate less than int64_t overflow. + EXPECT_FALSE(stats_.Rate(now_ms)); +} + } // namespace diff --git a/rtc_base/rate_tracker.cc b/rtc_base/rate_tracker.cc index 771dc6c148..5c827927f6 100644 --- a/rtc_base/rate_tracker.cc +++ b/rtc_base/rate_tracker.cc @@ -22,7 +22,7 @@ static const int64_t kTimeUnset = -1; RateTracker::RateTracker(int64_t bucket_milliseconds, size_t bucket_count) : bucket_milliseconds_(bucket_milliseconds), bucket_count_(bucket_count), - sample_buckets_(new size_t[bucket_count + 1]), + sample_buckets_(new int64_t[bucket_count + 1]), total_sample_count_(0u), bucket_start_time_milliseconds_(kTimeUnset) { RTC_CHECK(bucket_milliseconds > 0); @@ -76,10 +76,10 @@ double RateTracker::ComputeRateForInterval( size_t start_bucket = NextBucketIndex(current_bucket_ + buckets_to_skip); // Only count a portion of the first bucket according to how much of the // first bucket is within the current interval. - size_t total_samples = ((sample_buckets_[start_bucket] * - (bucket_milliseconds_ - milliseconds_to_skip)) + - (bucket_milliseconds_ >> 1)) / - bucket_milliseconds_; + int64_t total_samples = ((sample_buckets_[start_bucket] * + (bucket_milliseconds_ - milliseconds_to_skip)) + + (bucket_milliseconds_ >> 1)) / + bucket_milliseconds_; // All other buckets in the interval are counted in their entirety. for (size_t i = NextBucketIndex(start_bucket); i != NextBucketIndex(current_bucket_); i = NextBucketIndex(i)) { @@ -103,11 +103,12 @@ double RateTracker::ComputeTotalRate() const { TimeDiff(current_time, initialization_time_milliseconds_)); } -size_t RateTracker::TotalSampleCount() const { +int64_t RateTracker::TotalSampleCount() const { return total_sample_count_; } -void RateTracker::AddSamples(size_t sample_count) { +void RateTracker::AddSamples(int64_t sample_count) { + RTC_DCHECK_LE(0, sample_count); EnsureInitialized(); int64_t current_time = Time(); // Advance the current bucket as needed for the current time, and reset diff --git a/rtc_base/rate_tracker.h b/rtc_base/rate_tracker.h index e9be52260c..e42d40f14f 100644 --- a/rtc_base/rate_tracker.h +++ b/rtc_base/rate_tracker.h @@ -41,11 +41,11 @@ class RateTracker { double ComputeTotalRate() const; // The total number of samples added. - size_t TotalSampleCount() const; + int64_t TotalSampleCount() const; // Reads the current time in order to determine the appropriate bucket for // these samples, and increments the count for that bucket by sample_count. - void AddSamples(size_t sample_count); + void AddSamples(int64_t sample_count); protected: // overrideable for tests @@ -57,7 +57,7 @@ class RateTracker { const int64_t bucket_milliseconds_; const size_t bucket_count_; - size_t* sample_buckets_; + int64_t* sample_buckets_; size_t total_sample_count_; size_t current_bucket_; int64_t bucket_start_time_milliseconds_; diff --git a/rtc_base/rate_tracker_unittest.cc b/rtc_base/rate_tracker_unittest.cc index 7a2c1ad73b..22ae2c07e7 100644 --- a/rtc_base/rate_tracker_unittest.cc +++ b/rtc_base/rate_tracker_unittest.cc @@ -166,4 +166,13 @@ TEST(RateTrackerTest, TestGetUnitSecondsAfterInitialValue) { EXPECT_DOUBLE_EQ(1234.0, tracker.ComputeRateForInterval(1000)); } +TEST(RateTrackerTest, TestLargeNumbers) { + RateTrackerForTest tracker; + const uint64_t large_number = 0x100000000; + tracker.AddSamples(large_number); + tracker.AdvanceTime(1000); + tracker.AddSamples(large_number); + EXPECT_DOUBLE_EQ(large_number * 2, tracker.ComputeRate()); +} + } // namespace rtc