diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn index 2e7017d1fe..6a2c5aaff5 100644 --- a/rtc_base/BUILD.gn +++ b/rtc_base/BUILD.gn @@ -588,10 +588,12 @@ rtc_static_library("rtc_numerics") { sources = [ "numerics/exp_filter.cc", "numerics/exp_filter.h", + "numerics/math_utils.h", "numerics/moving_average.cc", "numerics/moving_average.h", "numerics/moving_median_filter.h", "numerics/percentile_filter.h", + "numerics/running_statistics.h", "numerics/samples_stats_counter.cc", "numerics/samples_stats_counter.h", "numerics/sequence_number_util.h", @@ -1297,6 +1299,7 @@ if (rtc_include_tests) { "numerics/moving_average_unittest.cc", "numerics/moving_median_filter_unittest.cc", "numerics/percentile_filter_unittest.cc", + "numerics/running_statistics_unittest.cc", "numerics/samples_stats_counter_unittest.cc", "numerics/sequence_number_util_unittest.cc", ] diff --git a/rtc_base/numerics/math_utils.h b/rtc_base/numerics/math_utils.h index 8a91958375..d5f3ee4073 100644 --- a/rtc_base/numerics/math_utils.h +++ b/rtc_base/numerics/math_utils.h @@ -36,4 +36,39 @@ typename std::make_unsigned::type unsigned_difference(T x, T y) { return static_cast(x) - static_cast(y); } +// Provides the neutral element of min(): +infinity, or max() if T has none. +// Typically used as an initial value for running minimum. +template ::has_infinity>::type* = + nullptr> +constexpr T infinity_or_max() { + return std::numeric_limits::infinity(); +} + +template ::has_infinity>::type* = nullptr> +constexpr T infinity_or_max() { + // T has no infinity (e.g. integral types): fall back to max(). + return std::numeric_limits::max(); +} + +// Provides the neutral element of max(): -infinity, or min() if T has none. +// Typically used as an initial value for running maximum. +template ::has_infinity>::type* = + nullptr> +constexpr T minus_infinity_or_min() { + static_assert(std::is_signed::value, "Unsupported. Please open a bug."); + return -std::numeric_limits::infinity(); +} + +template ::has_infinity>::type* = nullptr> +constexpr T minus_infinity_or_min() { + // Fallback to min(). NOTE(review): lowest() would also cover floating T.
+ return std::numeric_limits::min(); +} + #endif // RTC_BASE_NUMERICS_MATH_UTILS_H_ diff --git a/rtc_base/numerics/running_statistics.h b/rtc_base/numerics/running_statistics.h new file mode 100644 index 0000000000..d71323efb0 --- /dev/null +++ b/rtc_base/numerics/running_statistics.h @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef RTC_BASE_NUMERICS_RUNNING_STATISTICS_H_ +#define RTC_BASE_NUMERICS_RUNNING_STATISTICS_H_ + +#include +#include +#include + +#include "absl/types/optional.h" + +#include "rtc_base/numerics/math_utils.h" + +namespace webrtc { + +// tl;dr: Robust and efficient online computation of statistics, +// using Welford's method for variance. [1] +// +// This should be your go-to class if you ever need to compute +// min, max, mean, variance and standard deviation. +// If you need to get percentiles, please use webrtc::SamplesStatsCounter. +// +// The measures return absl::nullopt if no samples were fed (Size() == 0), +// otherwise the returned optional is guaranteed to contain a value. +// +// [1] +// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm + +// The type T is a scalar which must be convertible to double. +// Rationale: mean and variance are kept as double, since measures often +// need greater precision than the samples themselves. +template +class RunningStatistics { + public: + // Update stats //////////////////////////////////////////// + + // Adds a sample to the statistics in O(1) time.
+ void AddSample(T sample) { + max_ = std::max(max_, sample); + min_ = std::min(min_, sample); + ++size_; + // Welford's update: delta uses the old mean, delta2 the updated mean. + const double delta = sample - mean_; + mean_ += delta / size_; + const double delta2 = sample - mean_; + cumul_ += delta * delta2; + } + + // Merge other stats, as if samples were added one by one, but in O(1). + void MergeStatistics(const RunningStatistics& other) { + if (other.size_ == 0) { + return; + } + max_ = std::max(max_, other.max_); + min_ = std::min(min_, other.min_); + const int64_t new_size = size_ + other.size_; + const double new_mean = + (mean_ * size_ + other.mean_ * other.size_) / new_size; + // Each cumulant must be re-centered on the merged mean. + // * from: sum((x_i - mean_)²) + // * to: sum((x_i - new_mean)²) + auto delta = [new_mean](const RunningStatistics& stats) { + return stats.size_ * (new_mean * (new_mean - 2 * stats.mean_) + + stats.mean_ * stats.mean_); + }; + cumul_ = cumul_ + delta(*this) + other.cumul_ + delta(other); + mean_ = new_mean; + size_ = new_size; + } + + // Get Measures //////////////////////////////////////////// + + // Returns number of samples involved, i.e. AddSample() calls on this + // object plus the sizes of all statistics merged into it. + int64_t Size() const { return size_; } + + // Returns min in O(1) time. + absl::optional GetMin() const { + if (size_ == 0) { + return absl::nullopt; + } + return min_; + } + + // Returns max in O(1) time. + absl::optional GetMax() const { + if (size_ == 0) { + return absl::nullopt; + } + return max_; + } + + // Returns mean in O(1) time. + absl::optional GetMean() const { + if (size_ == 0) { + return absl::nullopt; + } + return mean_; + } + + // Returns the population variance (cumul_ / size_) in O(1) time. + absl::optional GetVariance() const { + if (size_ == 0) { + return absl::nullopt; + } + return cumul_ / size_; + } + + // Returns the population standard deviation in O(1) time.
+ absl::optional GetStandardDeviation() const { + if (size_ == 0) { + return absl::nullopt; + } + return std::sqrt(*GetVariance()); + } + + private: + int64_t size_ = 0; // Number of samples seen (added or merged). + T min_ = infinity_or_max(); + T max_ = minus_infinity_or_min(); + double mean_ = 0; + double cumul_ = 0; // Sum of squared deviations from mean_, a.k.a. M2. +}; + +} // namespace webrtc + +#endif // RTC_BASE_NUMERICS_RUNNING_STATISTICS_H_ diff --git a/rtc_base/numerics/running_statistics_unittest.cc b/rtc_base/numerics/running_statistics_unittest.cc new file mode 100644 index 0000000000..806b1e3a1c --- /dev/null +++ b/rtc_base/numerics/running_statistics_unittest.cc @@ -0,0 +1,131 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "rtc_base/numerics/running_statistics.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "test/gtest.h" + +// Tests were copied from samples_stats_counter_unittest.cc. + +namespace webrtc { +namespace { + +RunningStatistics CreateStatsFilledWithIntsFrom1ToN(int n) { + std::vector data; + for (int i = 1; i <= n; i++) { + data.push_back(i); + } + absl::c_shuffle(data, std::mt19937(std::random_device()())); + + RunningStatistics stats; + for (double v : data) { + stats.AddSample(v); + } + return stats; +} + +// Returns stats of n samples drawn from the uniform distribution on [a;b).
+RunningStatistics CreateStatsFromUniformDistribution(int n, + double a, + double b) { + std::mt19937 gen{std::random_device()()}; + std::uniform_real_distribution<> dis(a, b); + + RunningStatistics stats; + for (int i = 1; i <= n; i++) { + stats.AddSample(dis(gen)); + } + return stats; +} + +class RunningStatisticsTest : public ::testing::TestWithParam {}; + +constexpr int SIZE_FOR_MERGE = 5; + +} // namespace + +TEST(RunningStatisticsTest, FullSimpleTest) { + auto stats = CreateStatsFilledWithIntsFrom1ToN(100); + + EXPECT_DOUBLE_EQ(*stats.GetMin(), 1.0); + EXPECT_DOUBLE_EQ(*stats.GetMax(), 100.0); + EXPECT_DOUBLE_EQ(*stats.GetMean(), 50.5); +} + +TEST(RunningStatistics, VarianceAndDeviation) { + RunningStatistics stats; + stats.AddSample(2); + stats.AddSample(2); + stats.AddSample(-1); + stats.AddSample(5); + + EXPECT_DOUBLE_EQ(*stats.GetMean(), 2.0); + EXPECT_DOUBLE_EQ(*stats.GetVariance(), 4.5); + EXPECT_DOUBLE_EQ(*stats.GetStandardDeviation(), sqrt(4.5)); +} + +TEST(RunningStatisticsTest, VarianceFromUniformDistribution) { + // Check that the variance converges to 1/12 for a [0;1) uniform distribution. + // Acts as a sanity check for NumericStabilityForVariance test. + auto stats = CreateStatsFromUniformDistribution(1e6, 0, 1); + + EXPECT_NEAR(*stats.GetVariance(), 1. / 12, 1e-3); +} + +TEST(RunningStatisticsTest, NumericStabilityForVariance) { + // Same test as VarianceFromUniformDistribution, + // except the range is shifted to [1e9;1e9+1). + // Variance should also converge to 1/12. + // NB: Although we lose precision for the samples themselves (1e9 takes ~30 + // of the 52 mantissa bits), the fractional part still enjoys 22 bits and + // errors should even out, so that couldn't explain a mismatch. + auto stats = CreateStatsFromUniformDistribution(1e6, 1e9, 1e9 + 1); + + EXPECT_NEAR(*stats.GetVariance(), 1. / 12, 1e-3); +} + +TEST_P(RunningStatisticsTest, MergeStatistics) { + int data[SIZE_FOR_MERGE] = {2, 2, -1, 5, 10}; + // Split the data into two partitions at index GetParam().
+ // The suite is instantiated with 6 split points (0..SIZE_FOR_MERGE): + // * Empty merged with full sequence. + // * 1 sample merged with 4 last. + // * 2 samples merged with 3 last. + // [...] + // * Full merged with empty sequence. + // All must lead to the same result. + // A property-based test (QuickCheck-style) would fit well here. + RunningStatistics stats0, stats1; + for (int i = 0; i < GetParam(); ++i) { + stats0.AddSample(data[i]); + } + for (int i = GetParam(); i < SIZE_FOR_MERGE; ++i) { + stats1.AddSample(data[i]); + } + stats0.MergeStatistics(stats1); + + EXPECT_EQ(stats0.Size(), SIZE_FOR_MERGE); + EXPECT_DOUBLE_EQ(*stats0.GetMin(), -1); + EXPECT_DOUBLE_EQ(*stats0.GetMax(), 10); + EXPECT_DOUBLE_EQ(*stats0.GetMean(), 3.6); + EXPECT_DOUBLE_EQ(*stats0.GetVariance(), 13.84); + EXPECT_DOUBLE_EQ(*stats0.GetStandardDeviation(), sqrt(13.84)); +} + +INSTANTIATE_TEST_SUITE_P(RunningStatisticsTests, + RunningStatisticsTest, + ::testing::Range(0, SIZE_FOR_MERGE + 1)); + +} // namespace webrtc diff --git a/rtc_base/numerics/samples_stats_counter.cc b/rtc_base/numerics/samples_stats_counter.cc index 134a65db1c..655f4c1fe6 100644 --- a/rtc_base/numerics/samples_stats_counter.cc +++ b/rtc_base/numerics/samples_stats_counter.cc @@ -26,26 +26,15 @@ SamplesStatsCounter& SamplesStatsCounter::operator=(SamplesStatsCounter&&) = default; void SamplesStatsCounter::AddSample(double value) { + stats_.AddSample(value); samples_.push_back(value); sorted_ = false; - if (value > max_) { - max_ = value; - } - if (value < min_) { - min_ = value; - } - sum_ += value; - sum_squared_ += value * value; } void SamplesStatsCounter::AddSamples(const SamplesStatsCounter& other) { - for (double sample : other.samples_) - samples_.push_back(sample); + stats_.MergeStatistics(other.stats_); + samples_.insert(samples_.end(), other.samples_.begin(), other.samples_.end()); sorted_ = false; - max_ = std::max(max_, other.max_); - min_ = std::min(min_, other.min_); - sum_ += other.sum_; - sum_squared_ += other.sum_squared_; } double SamplesStatsCounter::GetPercentile(double
percentile) { diff --git a/rtc_base/numerics/samples_stats_counter.h b/rtc_base/numerics/samples_stats_counter.h index 05a8c145f5..ac5f12cdea 100644 --- a/rtc_base/numerics/samples_stats_counter.h +++ b/rtc_base/numerics/samples_stats_counter.h @@ -11,14 +11,15 @@ #ifndef RTC_BASE_NUMERICS_SAMPLES_STATS_COUNTER_H_ #define RTC_BASE_NUMERICS_SAMPLES_STATS_COUNTER_H_ -#include -#include #include #include "rtc_base/checks.h" +#include "rtc_base/numerics/running_statistics.h" namespace webrtc { +// This class wraps a RunningStatistics member and adds a GetPercentile() +// method, while slightly adapting the interface. class SamplesStatsCounter { public: SamplesStatsCounter(); @@ -41,31 +42,31 @@ class SamplesStatsCounter { // samples. double GetMin() const { RTC_DCHECK(!IsEmpty()); - return min_; + return *stats_.GetMin(); } // Returns max in O(1) time. This function may not be called if there are no // samples. double GetMax() const { RTC_DCHECK(!IsEmpty()); - return max_; + return *stats_.GetMax(); } // Returns average in O(1) time. This function may not be called if there are // no samples. double GetAverage() const { RTC_DCHECK(!IsEmpty()); - return sum_ / samples_.size(); + return *stats_.GetMean(); } // Returns variance in O(1) time. This function may not be called if there are // no samples. double GetVariance() const { RTC_DCHECK(!IsEmpty()); - return sum_squared_ / samples_.size() - GetAverage() * GetAverage(); + return *stats_.GetVariance(); } // Returns standard deviation in O(1) time. This function may not be called if // there are no samples. double GetStandardDeviation() const { RTC_DCHECK(!IsEmpty()); - return sqrt(GetVariance()); + return *stats_.GetStandardDeviation(); } // Returns percentile in O(nlogn) on first call and in O(1) after, if no // additions were done.
This function may not be called if there are no @@ -76,11 +77,8 @@ class SamplesStatsCounter { double GetPercentile(double percentile); private: + RunningStatistics stats_; std::vector samples_; - double min_ = std::numeric_limits::max(); - double max_ = std::numeric_limits::min(); - double sum_ = 0; - double sum_squared_ = 0; bool sorted_ = false; }; diff --git a/rtc_base/numerics/samples_stats_counter_unittest.cc b/rtc_base/numerics/samples_stats_counter_unittest.cc index 8634295c8e..590bf8c785 100644 --- a/rtc_base/numerics/samples_stats_counter_unittest.cc +++ b/rtc_base/numerics/samples_stats_counter_unittest.cc @@ -34,6 +34,24 @@ SamplesStatsCounter CreateStatsFilledWithIntsFrom1ToN(int n) { return stats; } +// Returns stats of n samples drawn from the uniform distribution on [a;b). +SamplesStatsCounter CreateStatsFromUniformDistribution(int n, + double a, + double b) { + std::mt19937 gen{std::random_device()()}; + std::uniform_real_distribution<> dis(a, b); + + SamplesStatsCounter stats; + for (int i = 1; i <= n; i++) { + stats.AddSample(dis(gen)); + } + return stats; +} + +class SamplesStatsCounterTest : public ::testing::TestWithParam {}; + +constexpr int SIZE_FOR_MERGE = 10; + } // namespace TEST(SamplesStatsCounter, FullSimpleTest) { @@ -76,4 +94,58 @@ TEST(SamplesStatsCounter, TestBorderValues) { EXPECT_DOUBLE_EQ(stats.GetPercentile(1.0), 5); } +TEST(SamplesStatsCounter, VarianceFromUniformDistribution) { + // Check that the variance converges to 1/12 for a [0;1) uniform distribution. + // Acts as a sanity check for NumericStabilityForVariance test. + SamplesStatsCounter stats = CreateStatsFromUniformDistribution(1e6, 0, 1); + + EXPECT_NEAR(stats.GetVariance(), 1. / 12, 1e-3); +} + +TEST(SamplesStatsCounter, NumericStabilityForVariance) { + // Same test as VarianceFromUniformDistribution, + // except the range is shifted to [1e9;1e9+1). + // Variance should also converge to 1/12.
+ // NB: Although we lose precision for the samples themselves (1e9 takes ~30 + // of the 52 mantissa bits), the fractional part still enjoys 22 bits and + // errors should even out, so that couldn't explain a mismatch. + SamplesStatsCounter stats = + CreateStatsFromUniformDistribution(1e6, 1e9, 1e9 + 1); + + EXPECT_NEAR(stats.GetVariance(), 1. / 12, 1e-3); +} + +TEST_P(SamplesStatsCounterTest, AddSamples) { + int data[SIZE_FOR_MERGE] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + // Split the data into two partitions at index GetParam(). + // We have 11 distinct tests: + // * Empty merged with full sequence. + // * 1 sample merged with 9 last. + // * 2 samples merged with 8 last. + // [...] + // * Full merged with empty sequence. + // All must lead to the same result. + SamplesStatsCounter stats0, stats1; + for (int i = 0; i < GetParam(); ++i) { + stats0.AddSample(data[i]); + } + for (int i = GetParam(); i < SIZE_FOR_MERGE; ++i) { + stats1.AddSample(data[i]); + } + stats0.AddSamples(stats1); + + EXPECT_EQ(stats0.GetMin(), 0); + EXPECT_EQ(stats0.GetMax(), 9); + EXPECT_DOUBLE_EQ(stats0.GetAverage(), 4.5); + EXPECT_DOUBLE_EQ(stats0.GetVariance(), 8.25); + EXPECT_DOUBLE_EQ(stats0.GetStandardDeviation(), sqrt(8.25)); + EXPECT_DOUBLE_EQ(stats0.GetPercentile(0.1), 0.9); + EXPECT_DOUBLE_EQ(stats0.GetPercentile(0.5), 4.5); + EXPECT_DOUBLE_EQ(stats0.GetPercentile(0.9), 8.1); +} + +INSTANTIATE_TEST_SUITE_P(SamplesStatsCounterTests, + SamplesStatsCounterTest, + ::testing::Range(0, SIZE_FOR_MERGE + 1)); + +} // namespace webrtc