From 70b775d77f9de33cc4f48437e62bc9192a21001b Mon Sep 17 00:00:00 2001 From: Alessio Bazzica Date: Wed, 7 Apr 2021 12:03:11 +0200 Subject: [PATCH] AGC2 noise estimator code style improvements Code style improvements done in preparation for a bug fix (TODO added) which requires changes in the unit tests. Note that one expected value in the unit tests has been adjusted since the white noise generator is now instanced in each separate test and therefore, even if the seed remained the same, the generated sequences differ. Bug: webrtc:7494 Change-Id: I497513b84f50b5c66cf6241a09946ce853eb1cd2 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/214122 Commit-Queue: Alessio Bazzica Reviewed-by: Ivo Creusen Cr-Commit-Position: refs/heads/master@{#33636} --- modules/audio_processing/agc2/BUILD.gn | 1 + .../agc2/agc2_testing_common.cc | 72 ++++++++++++++++-- .../agc2/agc2_testing_common.h | 74 ++++++++++--------- modules/audio_processing/agc2/down_sampler.cc | 4 +- modules/audio_processing/agc2/down_sampler.h | 2 +- .../agc2/noise_level_estimator.cc | 48 +++++++----- .../agc2/noise_level_estimator.h | 6 +- .../agc2/noise_level_estimator_unittest.cc | 64 ++++++++-------- .../agc2/noise_spectrum_estimator.cc | 4 +- .../agc2/signal_classifier.cc | 4 +- .../agc2/signal_classifier_unittest.cc | 60 ++++++++------- 11 files changed, 212 insertions(+), 127 deletions(-) diff --git a/modules/audio_processing/agc2/BUILD.gn b/modules/audio_processing/agc2/BUILD.gn index 8f2ee0fddd..910b58c9c2 100644 --- a/modules/audio_processing/agc2/BUILD.gn +++ b/modules/audio_processing/agc2/BUILD.gn @@ -248,6 +248,7 @@ rtc_library("noise_estimator_unittests") { "..:apm_logging", "..:audio_frame_view", "../../../api:array_view", + "../../../api:function_view", "../../../rtc_base:checks", "../../../rtc_base:gunit_helpers", "../../../rtc_base:rtc_base_approved", diff --git a/modules/audio_processing/agc2/agc2_testing_common.cc b/modules/audio_processing/agc2/agc2_testing_common.cc index 6c22492e88..125e551b72 100644 --- a/modules/audio_processing/agc2/agc2_testing_common.cc +++ b/modules/audio_processing/agc2/agc2_testing_common.cc @@ -10,24 +10,84 @@ #include "modules/audio_processing/agc2/agc2_testing_common.h" +#include + #include "rtc_base/checks.h" namespace webrtc { - namespace test { -std::vector LinSpace(const double l, - const double r, - size_t num_points) { - RTC_CHECK(num_points >= 2); +std::vector LinSpace(double l, double r, int num_points) { + RTC_CHECK_GE(num_points, 2); std::vector points(num_points); const double step = (r - l) / (num_points - 1.0); points[0] = l; - for (size_t i = 1; i < num_points - 1; i++) { + for (int i = 1; i < num_points - 1; i++) { points[i] = static_cast(l) + i * step; } points[num_points - 1] = r; return points; } + +WhiteNoiseGenerator::WhiteNoiseGenerator(int min_amplitude, int max_amplitude) + : rand_gen_(42), + min_amplitude_(min_amplitude), + max_amplitude_(max_amplitude) { + RTC_DCHECK_LT(min_amplitude_, max_amplitude_); + RTC_DCHECK_LE(kMinS16, min_amplitude_); + RTC_DCHECK_LE(min_amplitude_, kMaxS16); + RTC_DCHECK_LE(kMinS16, max_amplitude_); + RTC_DCHECK_LE(max_amplitude_, kMaxS16); +} + +float WhiteNoiseGenerator::operator()() { + return static_cast(rand_gen_.Rand(min_amplitude_, max_amplitude_)); +} + +SineGenerator::SineGenerator(float amplitude, + float frequency_hz, + int sample_rate_hz) + : amplitude_(amplitude), + frequency_hz_(frequency_hz), + sample_rate_hz_(sample_rate_hz), + x_radians_(0.0f) { + RTC_DCHECK_GT(amplitude_, 0); + RTC_DCHECK_LE(amplitude_, kMaxS16); +} + +float SineGenerator::operator()() { + constexpr float kPi = 3.1415926536f; + x_radians_ += frequency_hz_ / sample_rate_hz_ * 2 * kPi; + if (x_radians_ >= 2 * kPi) { + x_radians_ -= 2 * kPi; + } + return amplitude_ * std::sinf(x_radians_); +} + +PulseGenerator::PulseGenerator(float pulse_amplitude, + float no_pulse_amplitude, + float frequency_hz, + int sample_rate_hz) + : pulse_amplitude_(pulse_amplitude), + no_pulse_amplitude_(no_pulse_amplitude), + samples_period_( + static_cast(static_cast(sample_rate_hz) / frequency_hz)), + sample_counter_(0) { + RTC_DCHECK_GE(pulse_amplitude_, kMinS16); + RTC_DCHECK_LE(pulse_amplitude_, kMaxS16); + RTC_DCHECK_GT(no_pulse_amplitude_, kMinS16); + RTC_DCHECK_LE(no_pulse_amplitude_, kMaxS16); + RTC_DCHECK_GT(sample_rate_hz, frequency_hz); +} + +float PulseGenerator::operator()() { + sample_counter_++; + if (sample_counter_ >= samples_period_) { + sample_counter_ -= samples_period_; + } + return static_cast(sample_counter_ == 0 ? pulse_amplitude_ + : no_pulse_amplitude_); +} + } // namespace test } // namespace webrtc diff --git a/modules/audio_processing/agc2/agc2_testing_common.h b/modules/audio_processing/agc2/agc2_testing_common.h index 7bfadbb3fc..4572d9cffd 100644 --- a/modules/audio_processing/agc2/agc2_testing_common.h +++ b/modules/audio_processing/agc2/agc2_testing_common.h @@ -11,17 +11,19 @@ #ifndef MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_ #define MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_ -#include - #include #include -#include "rtc_base/checks.h" +#include "rtc_base/random.h" namespace webrtc { - namespace test { +constexpr float kMinS16 = + static_cast(std::numeric_limits::min()); +constexpr float kMaxS16 = + static_cast(std::numeric_limits::max()); + // Level Estimator test parameters. constexpr float kDecayMs = 500.f; @@ -29,47 +31,49 @@ constexpr float kDecayMs = 500.f; constexpr float kLimiterMaxInputLevelDbFs = 1.f; constexpr float kLimiterKneeSmoothnessDb = 1.f; constexpr float kLimiterCompressionRatio = 5.f; -constexpr float kPi = 3.1415926536f; -std::vector LinSpace(const double l, const double r, size_t num_points); +// Returns evenly spaced `num_points` numbers over a specified interval [l, r]. +std::vector LinSpace(double l, double r, int num_points); -class SineGenerator { +// Generates white noise. +class WhiteNoiseGenerator { public: - SineGenerator(float frequency, int rate) - : frequency_(frequency), rate_(rate) {} - float operator()() { - x_radians_ += frequency_ / rate_ * 2 * kPi; - if (x_radians_ > 2 * kPi) { - x_radians_ -= 2 * kPi; - } - return 1000.f * sinf(x_radians_); - } + WhiteNoiseGenerator(int min_amplitude, int max_amplitude); + float operator()(); private: - float frequency_; - int rate_; - float x_radians_ = 0.f; + Random rand_gen_; + const int min_amplitude_; + const int max_amplitude_; }; -class PulseGenerator { +// Generates a sine function. +class SineGenerator { public: - PulseGenerator(float frequency, int rate) - : samples_period_( - static_cast(static_cast(rate) / frequency)) { - RTC_DCHECK_GT(rate, frequency); - } - float operator()() { - sample_counter_++; - if (sample_counter_ >= samples_period_) { - sample_counter_ -= samples_period_; - } - return static_cast( - sample_counter_ == 0 ? std::numeric_limits::max() : 10.f); - } + SineGenerator(float amplitude, float frequency_hz, int sample_rate_hz); + float operator()(); private: - int samples_period_; - int sample_counter_ = 0; + const float amplitude_; + const float frequency_hz_; + const int sample_rate_hz_; + float x_radians_; +}; + +// Generates periodic pulses. +class PulseGenerator { + public: + PulseGenerator(float pulse_amplitude, + float no_pulse_amplitude, + float frequency_hz, + int sample_rate_hz); + float operator()(); + + private: + const float pulse_amplitude_; + const float no_pulse_amplitude_; + const int samples_period_; + int sample_counter_; }; } // namespace test diff --git a/modules/audio_processing/agc2/down_sampler.cc b/modules/audio_processing/agc2/down_sampler.cc index 654ed4be37..fd1a2c3a46 100644 --- a/modules/audio_processing/agc2/down_sampler.cc +++ b/modules/audio_processing/agc2/down_sampler.cc @@ -72,7 +72,7 @@ void DownSampler::Initialize(int sample_rate_hz) { void DownSampler::DownSample(rtc::ArrayView in, rtc::ArrayView out) { - data_dumper_->DumpWav("lc_down_sampler_input", in, sample_rate_hz_, 1); + data_dumper_->DumpWav("agc2_down_sampler_input", in, sample_rate_hz_, 1); RTC_DCHECK_EQ(sample_rate_hz_ * kChunkSizeMs / 1000, in.size()); RTC_DCHECK_EQ(kSampleRate8kHz * kChunkSizeMs / 1000, out.size()); const size_t kMaxNumFrames = kSampleRate48kHz * kChunkSizeMs / 1000; @@ -93,7 +93,7 @@ void DownSampler::DownSample(rtc::ArrayView in, std::copy(in.data(), in.data() + in.size(), out.data()); } - data_dumper_->DumpWav("lc_down_sampler_output", out, kSampleRate8kHz, 1); + data_dumper_->DumpWav("agc2_down_sampler_output", out, kSampleRate8kHz, 1); } } // namespace webrtc diff --git a/modules/audio_processing/agc2/down_sampler.h b/modules/audio_processing/agc2/down_sampler.h index be7cbb3da7..a44f96fa2d 100644 --- a/modules/audio_processing/agc2/down_sampler.h +++ b/modules/audio_processing/agc2/down_sampler.h @@ -31,7 +31,7 @@ class DownSampler { void DownSample(rtc::ArrayView in, rtc::ArrayView out); private: - ApmDataDumper* data_dumper_; + ApmDataDumper* const data_dumper_; int sample_rate_hz_; int down_sampling_factor_; BiQuadFilter low_pass_filter_; diff --git a/modules/audio_processing/agc2/noise_level_estimator.cc b/modules/audio_processing/agc2/noise_level_estimator.cc index 2ca5034334..d50ecbac96 100644 --- a/modules/audio_processing/agc2/noise_level_estimator.cc +++ b/modules/audio_processing/agc2/noise_level_estimator.cc @@ -27,10 +27,10 @@ namespace { constexpr int kFramesPerSecond = 100; float FrameEnergy(const AudioFrameView& audio) { - float energy = 0.f; + float energy = 0.0f; for (size_t k = 0; k < audio.num_channels(); ++k) { float channel_energy = - std::accumulate(audio.channel(k).begin(), audio.channel(k).end(), 0.f, + std::accumulate(audio.channel(k).begin(), audio.channel(k).end(), 0.0f, [](float a, float b) -> float { return a + b * b; }); energy = std::max(channel_energy, energy); } @@ -44,7 +44,7 @@ float EnergyToDbfs(float signal_energy, size_t num_samples) { } // namespace NoiseLevelEstimator::NoiseLevelEstimator(ApmDataDumper* data_dumper) - : signal_classifier_(data_dumper) { + : data_dumper_(data_dumper), signal_classifier_(data_dumper) { Initialize(48000); } @@ -52,35 +52,40 @@ NoiseLevelEstimator::~NoiseLevelEstimator() {} void NoiseLevelEstimator::Initialize(int sample_rate_hz) { sample_rate_hz_ = sample_rate_hz; - noise_energy_ = 1.f; + noise_energy_ = 1.0f; first_update_ = true; - min_noise_energy_ = sample_rate_hz * 2.f * 2.f / kFramesPerSecond; + min_noise_energy_ = sample_rate_hz * 2.0f * 2.0f / kFramesPerSecond; noise_energy_hold_counter_ = 0; signal_classifier_.Initialize(sample_rate_hz); } float NoiseLevelEstimator::Analyze(const AudioFrameView& frame) { - const int rate = + data_dumper_->DumpRaw("agc2_noise_level_estimator_hold_counter", + noise_energy_hold_counter_); + const int sample_rate_hz = static_cast(frame.samples_per_channel() * kFramesPerSecond); - if (rate != sample_rate_hz_) { - Initialize(rate); + if (sample_rate_hz != sample_rate_hz_) { + Initialize(sample_rate_hz); } const float frame_energy = FrameEnergy(frame); if (frame_energy <= 0.f) { RTC_DCHECK_GE(frame_energy, 0.f); + data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type", -1); return EnergyToDbfs(noise_energy_, frame.samples_per_channel()); } if (first_update_) { // Initialize the noise energy to the frame energy. first_update_ = false; - return EnergyToDbfs( - noise_energy_ = std::max(frame_energy, min_noise_energy_), - frame.samples_per_channel()); + data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type", -1); + noise_energy_ = std::max(frame_energy, min_noise_energy_); + return EnergyToDbfs(noise_energy_, frame.samples_per_channel()); } const SignalClassifier::SignalType signal_type = signal_classifier_.Analyze(frame.channel(0)); + data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type", + static_cast(signal_type)); // Update the noise estimate in a minimum statistics-type manner. if (signal_type == SignalClassifier::SignalType::kStationary) { @@ -90,25 +95,32 @@ float NoiseLevelEstimator::Analyze(const AudioFrameView& frame) { noise_energy_hold_counter_ = std::max(noise_energy_hold_counter_ - 1, 0); if (noise_energy_hold_counter_ == 0) { - noise_energy_ = std::min(noise_energy_ * 1.01f, frame_energy); + constexpr float kMaxNoiseEnergyFactor = 1.01f; + noise_energy_ = + std::min(noise_energy_ * kMaxNoiseEnergyFactor, frame_energy); } } else { // Update smoothly downwards with a limited maximum update magnitude. + constexpr float kMinNoiseEnergyFactor = 0.9f; + constexpr float kNoiseEnergyDeltaFactor = 0.05f; noise_energy_ = - std::max(noise_energy_ * 0.9f, - noise_energy_ + 0.05f * (frame_energy - noise_energy_)); - noise_energy_hold_counter_ = 1000; + std::max(noise_energy_ * kMinNoiseEnergyFactor, + noise_energy_ - kNoiseEnergyDeltaFactor * + (noise_energy_ - frame_energy)); + // Prevent an energy increase for the next 10 seconds. + constexpr int kNumFramesToEnergyIncreaseAllowed = 1000; + noise_energy_hold_counter_ = kNumFramesToEnergyIncreaseAllowed; } } else { + // TODO(bugs.webrtc.org/7494): Remove to not forget the estimated level. // For a non-stationary signal, leak the estimate downwards in order to // avoid estimate locking due to incorrect signal classification. noise_energy_ = noise_energy_ * 0.99f; } // Ensure a minimum of the estimate. - return EnergyToDbfs( - noise_energy_ = std::max(noise_energy_, min_noise_energy_), - frame.samples_per_channel()); + noise_energy_ = std::max(noise_energy_, min_noise_energy_); + return EnergyToDbfs(noise_energy_, frame.samples_per_channel()); } } // namespace webrtc diff --git a/modules/audio_processing/agc2/noise_level_estimator.h b/modules/audio_processing/agc2/noise_level_estimator.h index ca2f9f2e2f..65d462342a 100644 --- a/modules/audio_processing/agc2/noise_level_estimator.h +++ b/modules/audio_processing/agc2/noise_level_estimator.h @@ -13,7 +13,6 @@ #include "modules/audio_processing/agc2/signal_classifier.h" #include "modules/audio_processing/include/audio_frame_view.h" -#include "rtc_base/constructor_magic.h" namespace webrtc { class ApmDataDumper; @@ -21,6 +20,8 @@ class ApmDataDumper; class NoiseLevelEstimator { public: NoiseLevelEstimator(ApmDataDumper* data_dumper); + NoiseLevelEstimator(const NoiseLevelEstimator&) = delete; + NoiseLevelEstimator& operator=(const NoiseLevelEstimator&) = delete; ~NoiseLevelEstimator(); // Returns the estimated noise level in dBFS. float Analyze(const AudioFrameView& frame); @@ -28,14 +29,13 @@ class NoiseLevelEstimator { private: void Initialize(int sample_rate_hz); + ApmDataDumper* const data_dumper_; int sample_rate_hz_; float min_noise_energy_; bool first_update_; float noise_energy_; int noise_energy_hold_counter_; SignalClassifier signal_classifier_; - - RTC_DISALLOW_COPY_AND_ASSIGN(NoiseLevelEstimator); }; } // namespace webrtc diff --git a/modules/audio_processing/agc2/noise_level_estimator_unittest.cc b/modules/audio_processing/agc2/noise_level_estimator_unittest.cc index c4fd33b0a0..327fceee8a 100644 --- a/modules/audio_processing/agc2/noise_level_estimator_unittest.cc +++ b/modules/audio_processing/agc2/noise_level_estimator_unittest.cc @@ -14,30 +14,31 @@ #include #include +#include "api/function_view.h" #include "modules/audio_processing/agc2/agc2_testing_common.h" #include "modules/audio_processing/agc2/vector_float_frame.h" #include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/gunit.h" -#include "rtc_base/random.h" namespace webrtc { namespace { -Random rand_gen(42); -ApmDataDumper data_dumper(0); + constexpr int kNumIterations = 200; constexpr int kFramesPerSecond = 100; // Runs the noise estimator on audio generated by 'sample_generator' // for kNumIterations. Returns the last noise level estimate. -float RunEstimator(std::function sample_generator, int rate) { +float RunEstimator(rtc::FunctionView sample_generator, + int sample_rate_hz) { + ApmDataDumper data_dumper(0); NoiseLevelEstimator estimator(&data_dumper); - const size_t samples_per_channel = - rtc::CheckedDivExact(rate, kFramesPerSecond); - VectorFloatFrame signal(1, static_cast(samples_per_channel), 0.f); + const int samples_per_channel = + rtc::CheckedDivExact(sample_rate_hz, kFramesPerSecond); + VectorFloatFrame signal(1, samples_per_channel, 0.0f); for (int i = 0; i < kNumIterations; ++i) { AudioFrameView frame_view = signal.float_frame_view(); - for (size_t j = 0; j < samples_per_channel; ++j) { + for (int j = 0; j < samples_per_channel; ++j) { frame_view.channel(0)[j] = sample_generator(); } estimator.Analyze(frame_view); @@ -45,39 +46,42 @@ float RunEstimator(std::function sample_generator, int rate) { return estimator.Analyze(signal.float_frame_view()); } -float WhiteNoiseGenerator() { - return static_cast(rand_gen.Rand(std::numeric_limits::min(), - std::numeric_limits::max())); -} -} // namespace +class NoiseEstimatorParametrization : public ::testing::TestWithParam { + protected: + int sample_rate_hz() const { return GetParam(); } +}; // White random noise is stationary, but does not trigger the detector // every frame due to the randomness. -TEST(AutomaticGainController2NoiseEstimator, RandomNoise) { - for (const auto rate : {8000, 16000, 32000, 48000}) { - const float noise_level = RunEstimator(WhiteNoiseGenerator, rate); - EXPECT_NEAR(noise_level, -5.f, 1.f); - } +TEST_P(NoiseEstimatorParametrization, RandomNoise) { + test::WhiteNoiseGenerator gen(/*min_amplitude=*/test::kMinS16, + /*max_amplitude=*/test::kMaxS16); + const float noise_level_dbfs = RunEstimator(gen, sample_rate_hz()); + EXPECT_NEAR(noise_level_dbfs, -5.5f, 1.0f); } // Sine curves are (very) stationary. They trigger the detector all // the time. Except for a few initial frames. -TEST(AutomaticGainController2NoiseEstimator, SineTone) { - for (const auto rate : {8000, 16000, 32000, 48000}) { - test::SineGenerator gen(600.f, rate); - const float noise_level = RunEstimator(gen, rate); - EXPECT_NEAR(noise_level, -33.f, 1.f); - } +TEST_P(NoiseEstimatorParametrization, SineTone) { + test::SineGenerator gen(/*amplitude=*/test::kMaxS16, /*frequency_hz=*/600.0f, + sample_rate_hz()); + const float noise_level_dbfs = RunEstimator(gen, sample_rate_hz()); + EXPECT_NEAR(noise_level_dbfs, -3.0f, 1.0f); } // Pulses are transient if they are far enough apart. They shouldn't // trigger the noise detector. -TEST(AutomaticGainController2NoiseEstimator, PulseTone) { - for (const auto rate : {8000, 16000, 32000, 48000}) { - test::PulseGenerator gen(20.f, rate); - const int noise_level = RunEstimator(gen, rate); - EXPECT_NEAR(noise_level, -79.f, 1.f); - } +TEST_P(NoiseEstimatorParametrization, PulseTone) { + test::PulseGenerator gen(/*pulse_amplitude=*/test::kMaxS16, + /*no_pulse_amplitude=*/10.0f, /*frequency_hz=*/20.0f, + sample_rate_hz()); + const int noise_level_dbfs = RunEstimator(gen, sample_rate_hz()); + EXPECT_NEAR(noise_level_dbfs, -79.0f, 1.0f); } +INSTANTIATE_TEST_SUITE_P(GainController2NoiseEstimator, + NoiseEstimatorParametrization, + ::testing::Values(8000, 16000, 32000, 48000)); + +} // namespace } // namespace webrtc diff --git a/modules/audio_processing/agc2/noise_spectrum_estimator.cc b/modules/audio_processing/agc2/noise_spectrum_estimator.cc index 31438b1f49..f283f4e27f 100644 --- a/modules/audio_processing/agc2/noise_spectrum_estimator.cc +++ b/modules/audio_processing/agc2/noise_spectrum_estimator.cc @@ -63,8 +63,8 @@ void NoiseSpectrumEstimator::Update(rtc::ArrayView spectrum, v = std::max(v, kMinNoisePower); } - data_dumper_->DumpRaw("lc_noise_spectrum", 65, noise_spectrum_); - data_dumper_->DumpRaw("lc_signal_spectrum", spectrum); + data_dumper_->DumpRaw("agc2_noise_spectrum", 65, noise_spectrum_); + data_dumper_->DumpRaw("agc2_signal_spectrum", spectrum); } } // namespace webrtc diff --git a/modules/audio_processing/agc2/signal_classifier.cc b/modules/audio_processing/agc2/signal_classifier.cc index a06413d166..3ef8dd775b 100644 --- a/modules/audio_processing/agc2/signal_classifier.cc +++ b/modules/audio_processing/agc2/signal_classifier.cc @@ -84,8 +84,8 @@ webrtc::SignalClassifier::SignalType ClassifySignal( } } - data_dumper->DumpRaw("lc_num_stationary_bands", 1, &num_stationary_bands); - data_dumper->DumpRaw("lc_num_highly_nonstationary_bands", 1, + data_dumper->DumpRaw("agc2_num_stationary_bands", 1, &num_stationary_bands); + data_dumper->DumpRaw("agc2_num_highly_nonstationary_bands", 1, &num_highly_nonstationary_bands); // Use the detected number of bands to classify the overall signal diff --git a/modules/audio_processing/agc2/signal_classifier_unittest.cc b/modules/audio_processing/agc2/signal_classifier_unittest.cc index 62171b32e6..f1a3a664f0 100644 --- a/modules/audio_processing/agc2/signal_classifier_unittest.cc +++ b/modules/audio_processing/agc2/signal_classifier_unittest.cc @@ -14,25 +14,25 @@ #include #include +#include "api/function_view.h" #include "modules/audio_processing/agc2/agc2_testing_common.h" #include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/gunit.h" #include "rtc_base/random.h" namespace webrtc { - namespace { -Random rand_gen(42); -ApmDataDumper data_dumper(0); constexpr int kNumIterations = 100; // Runs the signal classifier on audio generated by 'sample_generator' // for kNumIterations. Returns the number of frames classified as noise. -int RunClassifier(std::function sample_generator, int rate) { +float RunClassifier(rtc::FunctionView sample_generator, + int sample_rate_hz) { + ApmDataDumper data_dumper(0); SignalClassifier classifier(&data_dumper); std::array signal; - classifier.Initialize(rate); - const size_t samples_per_channel = rtc::CheckedDivExact(rate, 100); + classifier.Initialize(sample_rate_hz); + const size_t samples_per_channel = rtc::CheckedDivExact(sample_rate_hz, 100); int number_of_noise_frames = 0; for (int i = 0; i < kNumIterations; ++i) { for (size_t j = 0; j < samples_per_channel; ++j) { @@ -45,38 +45,42 @@ int RunClassifier(std::function sample_generator, int rate) { return number_of_noise_frames; } -float WhiteNoiseGenerator() { - return static_cast(rand_gen.Rand(std::numeric_limits::min(), - std::numeric_limits::max())); -} -} // namespace +class SignalClassifierParametrization : public ::testing::TestWithParam { + protected: + int sample_rate_hz() const { return GetParam(); } +}; // White random noise is stationary, but does not trigger the detector // every frame due to the randomness. -TEST(AutomaticGainController2SignalClassifier, WhiteNoise) { - for (const auto rate : {8000, 16000, 32000, 48000}) { - const int number_of_noise_frames = RunClassifier(WhiteNoiseGenerator, rate); - EXPECT_GT(number_of_noise_frames, kNumIterations / 2); - } +TEST_P(SignalClassifierParametrization, WhiteNoise) { + test::WhiteNoiseGenerator gen(/*min_amplitude=*/test::kMinS16, + /*max_amplitude=*/test::kMaxS16); + const int number_of_noise_frames = RunClassifier(gen, sample_rate_hz()); + EXPECT_GT(number_of_noise_frames, kNumIterations / 2); } // Sine curves are (very) stationary. They trigger the detector all // the time. Except for a few initial frames. -TEST(AutomaticGainController2SignalClassifier, SineTone) { - for (const auto rate : {8000, 16000, 32000, 48000}) { - test::SineGenerator gen(600.f, rate); - const int number_of_noise_frames = RunClassifier(gen, rate); - EXPECT_GE(number_of_noise_frames, kNumIterations - 5); - } +TEST_P(SignalClassifierParametrization, SineTone) { + test::SineGenerator gen(/*amplitude=*/test::kMaxS16, /*frequency_hz=*/600.0f, + sample_rate_hz()); + const int number_of_noise_frames = RunClassifier(gen, sample_rate_hz()); + EXPECT_GE(number_of_noise_frames, kNumIterations - 5); } // Pulses are transient if they are far enough apart. They shouldn't // trigger the noise detector. -TEST(AutomaticGainController2SignalClassifier, PulseTone) { - for (const auto rate : {8000, 16000, 32000, 48000}) { - test::PulseGenerator gen(30.f, rate); - const int number_of_noise_frames = RunClassifier(gen, rate); - EXPECT_EQ(number_of_noise_frames, 0); - } +TEST_P(SignalClassifierParametrization, PulseTone) { + test::PulseGenerator gen(/*pulse_amplitude=*/test::kMaxS16, + /*no_pulse_amplitude=*/10.0f, /*frequency_hz=*/20.0f, + sample_rate_hz()); + const int number_of_noise_frames = RunClassifier(gen, sample_rate_hz()); + EXPECT_EQ(number_of_noise_frames, 0); } + +INSTANTIATE_TEST_SUITE_P(GainController2SignalClassifier, + SignalClassifierParametrization, + ::testing::Values(8000, 16000, 32000, 48000)); + +} // namespace } // namespace webrtc