From 4ed47d01901248a93c525a5f7c75015cf605cbff Mon Sep 17 00:00:00 2001 From: Alex Loiko Date: Wed, 4 Apr 2018 15:05:57 +0200 Subject: [PATCH] Noise level estimation for AGC2. We put back the old noise estimator from LevelController. We add a few new unit tests. We also re-arrange the code so that it fits with how it is used in AGC2. The differences are: 1. The NoiseLevelEstimator is now fully self-contained. 2. The NoiseLevelEstimator is responsible for calling SignalClassifier and computing the signal energy. Previously the signal type and energy were used in several places. It made sense to compute the values independently of the noise calculation. 3. Re-initialization doesn't have to be done by the caller. 4. The interface is AudioFrameView instead of AudioBuffer. # Bots are green, nothing should break internal stuff NOTRY=True Bug: webrtc:7494 Change-Id: I442bdbbeb3796eb2518e96000aec9dc5a039ae6d Reviewed-on: https://webrtc-review.googlesource.com/66380 Commit-Queue: Alex Loiko Reviewed-by: Sam Zackrisson Cr-Commit-Position: refs/heads/master@{#22738} --- modules/audio_processing/BUILD.gn | 1 + modules/audio_processing/agc2/BUILD.gn | 49 ++++- modules/audio_processing/agc2/adaptive_agc.cc | 3 +- .../agc2/agc2_testing_common.h | 45 +++++ .../audio_processing/agc2/biquad_filter.cc | 35 ++++ modules/audio_processing/agc2/biquad_filter.h | 56 ++++++ modules/audio_processing/agc2/down_sampler.cc | 98 ++++++++++ modules/audio_processing/agc2/down_sampler.h | 40 +++++ .../agc2/noise_level_estimator.cc | 97 +++++++++- .../agc2/noise_level_estimator.h | 17 +- .../agc2/noise_level_estimator_unittest.cc | 83 +++++++++ .../agc2/noise_spectrum_estimator.cc | 68 +++++++ .../agc2/noise_spectrum_estimator.h | 40 +++++ .../agc2/signal_classifier.cc | 167 ++++++++++++++++++ .../audio_processing/agc2/signal_classifier.h | 67 +++++++ .../agc2/signal_classifier_unittest.cc | 82 +++++++++ 16 files changed, 939 insertions(+), 9 deletions(-) create mode 100644 modules/audio_processing/agc2/biquad_filter.cc create mode 100644 modules/audio_processing/agc2/biquad_filter.h create mode 100644 modules/audio_processing/agc2/down_sampler.cc create mode 100644 modules/audio_processing/agc2/down_sampler.h create mode 100644 modules/audio_processing/agc2/noise_level_estimator_unittest.cc create mode 100644 modules/audio_processing/agc2/noise_spectrum_estimator.cc create mode 100644 modules/audio_processing/agc2/noise_spectrum_estimator.h create mode 100644 modules/audio_processing/agc2/signal_classifier.cc create mode 100644 modules/audio_processing/agc2/signal_classifier.h create mode 100644 modules/audio_processing/agc2/signal_classifier_unittest.cc diff --git a/modules/audio_processing/BUILD.gn b/modules/audio_processing/BUILD.gn index b36e76b1cc..b0b425329d 100644 --- a/modules/audio_processing/BUILD.gn +++ b/modules/audio_processing/BUILD.gn @@ -550,6 +550,7 @@ if (rtc_include_tests) { "aec_dump:mock_aec_dump_unittests", "agc2:adaptive_digital_unittests", "agc2:fixed_digital_unittests", + "agc2:noise_estimator_unittests", "test/conversational_speech:unittest", "vad:vad_unittests", "//testing/gtest", diff --git a/modules/audio_processing/agc2/BUILD.gn b/modules/audio_processing/agc2/BUILD.gn index ca71b9278f..f9051257e4 100644 --- a/modules/audio_processing/agc2/BUILD.gn +++ b/modules/audio_processing/agc2/BUILD.gn @@ -23,8 +23,6 @@ rtc_source_set("adaptive_digital") { "adaptive_digital_gain_applier.h", "adaptive_mode_level_estimator.cc", "adaptive_mode_level_estimator.h", - "noise_level_estimator.cc", - "noise_level_estimator.h", "saturation_protector.cc", "saturation_protector.h", ] @@ -33,6 +31,7 @@ rtc_source_set("adaptive_digital") { deps = [ ":common", + ":noise_level_estimator", "..:aec_core", "..:apm_logging", "..:audio_frame_view", @@ -83,6 +82,32 @@ rtc_source_set("common") { ] } +rtc_source_set("noise_level_estimator") { + sources = [ + "biquad_filter.cc", + "biquad_filter.h", + "down_sampler.cc", + "down_sampler.h", + "noise_level_estimator.cc", + "noise_level_estimator.h", + "noise_spectrum_estimator.cc", + "noise_spectrum_estimator.h", + "signal_classifier.cc", + "signal_classifier.h", + ] + deps = [ + "..:aec_core", + "..:apm_logging", + "..:audio_frame_view", + "../../../api:array_view", + "../../../common_audio", + "../../../rtc_base:checks", + "../../../rtc_base:macromagic", + ] + + configs += [ "..:apm_debug_dump" ] +} + rtc_source_set("test_utils") { testonly = true visibility = [ ":*" ] @@ -151,3 +176,23 @@ rtc_source_set("adaptive_digital_unittests") { "../vad:vad_with_level", ] } + +rtc_source_set("noise_estimator_unittests") { + testonly = true + configs += [ "..:apm_debug_dump" ] + + sources = [ + "noise_level_estimator_unittest.cc", + "signal_classifier_unittest.cc", + ] + deps = [ + ":noise_level_estimator", + ":test_utils", + "..:apm_logging", + "..:audio_frame_view", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../../../rtc_base:rtc_base_tests_utils", + ] +} diff --git a/modules/audio_processing/agc2/adaptive_agc.cc b/modules/audio_processing/agc2/adaptive_agc.cc index dff38fb44b..0de27a41a6 100644 --- a/modules/audio_processing/agc2/adaptive_agc.cc +++ b/modules/audio_processing/agc2/adaptive_agc.cc @@ -22,7 +22,8 @@ namespace webrtc { AdaptiveAgc::AdaptiveAgc(ApmDataDumper* apm_data_dumper) : speech_level_estimator_(apm_data_dumper), gain_applier_(apm_data_dumper), - apm_data_dumper_(apm_data_dumper) { + apm_data_dumper_(apm_data_dumper), + noise_level_estimator_(apm_data_dumper) { RTC_DCHECK(apm_data_dumper); } diff --git a/modules/audio_processing/agc2/agc2_testing_common.h b/modules/audio_processing/agc2/agc2_testing_common.h index a176282ede..8c4f400a98 100644 --- a/modules/audio_processing/agc2/agc2_testing_common.h +++ b/modules/audio_processing/agc2/agc2_testing_common.h @@ -11,9 +11,13 @@ #ifndef MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_ #define MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_ +#include + +#include #include #include "rtc_base/basictypes.h" +#include "rtc_base/checks.h" namespace webrtc { @@ -26,8 +30,49 @@ constexpr float kDecayMs = 500.f; constexpr float kLimiterMaxInputLevelDbFs = 1.f; constexpr float kLimiterKneeSmoothnessDb = 1.f; constexpr float kLimiterCompressionRatio = 5.f; +constexpr float kPi = 3.1415926536f; std::vector LinSpace(const double l, const double r, size_t num_points); + +class SineGenerator { + public: + SineGenerator(float frequency, int rate) + : frequency_(frequency), rate_(rate) {} + float operator()() { + x_radians_ += frequency_ / rate_ * 2 * kPi; + if (x_radians_ > 2 * kPi) { + x_radians_ -= 2 * kPi; + } + return 1000.f * sinf(x_radians_); + } + + private: + float frequency_; + int rate_; + float x_radians_ = 0.f; +}; + +class PulseGenerator { + public: + PulseGenerator(float frequency, int rate) + : samples_period_( + static_cast(static_cast(rate) / frequency)) { + RTC_DCHECK_GT(rate, frequency); + } + float operator()() { + sample_counter_++; + if (sample_counter_ >= samples_period_) { + sample_counter_ -= samples_period_; + } + return static_cast( + sample_counter_ == 0 ? std::numeric_limits::max() : 10.f); + } + + private: + int samples_period_; + int sample_counter_ = 0; +}; + } // namespace test } // namespace webrtc diff --git a/modules/audio_processing/agc2/biquad_filter.cc b/modules/audio_processing/agc2/biquad_filter.cc new file mode 100644 index 0000000000..c15c6449ad --- /dev/null +++ b/modules/audio_processing/agc2/biquad_filter.cc @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/biquad_filter.h" + +namespace webrtc { + +// This method applies a biquad filter to an input signal x to produce an +// output signal y. The biquad coefficients are specified at the construction +// of the object. +void BiQuadFilter::Process(rtc::ArrayView x, + rtc::ArrayView y) { + for (size_t k = 0; k < x.size(); ++k) { + // Use temporary variable for x[k] to allow in-place function call + // (that x and y refer to the same array). + const float tmp = x[k]; + y[k] = coefficients_.b[0] * tmp + coefficients_.b[1] * biquad_state_.b[0] + + coefficients_.b[2] * biquad_state_.b[1] - + coefficients_.a[0] * biquad_state_.a[0] - + coefficients_.a[1] * biquad_state_.a[1]; + biquad_state_.b[1] = biquad_state_.b[0]; + biquad_state_.b[0] = tmp; + biquad_state_.a[1] = biquad_state_.a[0]; + biquad_state_.a[0] = y[k]; + } +} + +} // namespace webrtc diff --git a/modules/audio_processing/agc2/biquad_filter.h b/modules/audio_processing/agc2/biquad_filter.h new file mode 100644 index 0000000000..4fd5e2e392 --- /dev/null +++ b/modules/audio_processing/agc2/biquad_filter.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_ + +#include "api/array_view.h" +#include "rtc_base/arraysize.h" +#include "rtc_base/constructormagic.h" + +namespace webrtc { + +class BiQuadFilter { + public: + struct BiQuadCoefficients { + float b[3]; + float a[2]; + }; + + BiQuadFilter() = default; + + void Initialize(const BiQuadCoefficients& coefficients) { + coefficients_ = coefficients; + } + + // Produces a filtered output y of the input x. Both x and y need to + // have the same length. + void Process(rtc::ArrayView x, rtc::ArrayView y); + + private: + struct BiQuadState { + BiQuadState() { + std::fill(b, b + arraysize(b), 0.f); + std::fill(a, a + arraysize(a), 0.f); + } + + float b[2]; + float a[2]; + }; + + BiQuadState biquad_state_; + BiQuadCoefficients coefficients_; + + RTC_DISALLOW_COPY_AND_ASSIGN(BiQuadFilter); +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_ diff --git a/modules/audio_processing/agc2/down_sampler.cc b/modules/audio_processing/agc2/down_sampler.cc new file mode 100644 index 0000000000..50486e0a36 --- /dev/null +++ b/modules/audio_processing/agc2/down_sampler.cc @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/down_sampler.h" + +#include +#include + +#include "modules/audio_processing/agc2/biquad_filter.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/checks.h" + +namespace webrtc { +namespace { + +constexpr int kChunkSizeMs = 10; +constexpr int kSampleRate8kHz = 8000; +constexpr int kSampleRate16kHz = 16000; +constexpr int kSampleRate32kHz = 32000; +constexpr int kSampleRate48kHz = 48000; + +// Bandlimiter coefficients computed based on that only +// the first 40 bins of the spectrum for the downsampled +// signal are used. +// [B,A] = butter(2,(41/64*4000)/8000) +const BiQuadFilter::BiQuadCoefficients kLowPassFilterCoefficients_16kHz = { + {0.1455f, 0.2911f, 0.1455f}, + {-0.6698f, 0.2520f}}; + +// [B,A] = butter(2,(41/64*4000)/16000) +const BiQuadFilter::BiQuadCoefficients kLowPassFilterCoefficients_32kHz = { + {0.0462f, 0.0924f, 0.0462f}, + {-1.3066f, 0.4915f}}; + +// [B,A] = butter(2,(41/64*4000)/24000) +const BiQuadFilter::BiQuadCoefficients kLowPassFilterCoefficients_48kHz = { + {0.0226f, 0.0452f, 0.0226f}, + {-1.5320f, 0.6224f}}; + +} // namespace + +DownSampler::DownSampler(ApmDataDumper* data_dumper) + : data_dumper_(data_dumper) { + Initialize(48000); +} +void DownSampler::Initialize(int sample_rate_hz) { + RTC_DCHECK( + sample_rate_hz == kSampleRate8kHz || sample_rate_hz == kSampleRate16kHz || + sample_rate_hz == kSampleRate32kHz || sample_rate_hz == kSampleRate48kHz); + + sample_rate_hz_ = sample_rate_hz; + down_sampling_factor_ = rtc::CheckedDivExact(sample_rate_hz_, 8000); + + /// Note that the down sampling filter is not used if the sample rate is 8 + /// kHz. + if (sample_rate_hz_ == kSampleRate16kHz) { + low_pass_filter_.Initialize(kLowPassFilterCoefficients_16kHz); + } else if (sample_rate_hz_ == kSampleRate32kHz) { + low_pass_filter_.Initialize(kLowPassFilterCoefficients_32kHz); + } else if (sample_rate_hz_ == kSampleRate48kHz) { + low_pass_filter_.Initialize(kLowPassFilterCoefficients_48kHz); + } +} + +void DownSampler::DownSample(rtc::ArrayView in, + rtc::ArrayView out) { + data_dumper_->DumpWav("lc_down_sampler_input", in, sample_rate_hz_, 1); + RTC_DCHECK_EQ(sample_rate_hz_ * kChunkSizeMs / 1000, in.size()); + RTC_DCHECK_EQ(kSampleRate8kHz * kChunkSizeMs / 1000, out.size()); + const size_t kMaxNumFrames = kSampleRate48kHz * kChunkSizeMs / 1000; + float x[kMaxNumFrames]; + + // Band-limit the signal to 4 kHz. + if (sample_rate_hz_ != kSampleRate8kHz) { + low_pass_filter_.Process(in, rtc::ArrayView(x, in.size())); + + // Downsample the signal. + size_t k = 0; + for (size_t j = 0; j < out.size(); ++j) { + RTC_DCHECK_GT(kMaxNumFrames, k); + out[j] = x[k]; + k += down_sampling_factor_; + } + } else { + std::copy(in.data(), in.data() + in.size(), out.data()); + } + + data_dumper_->DumpWav("lc_down_sampler_output", out, kSampleRate8kHz, 1); +} + +} // namespace webrtc diff --git a/modules/audio_processing/agc2/down_sampler.h b/modules/audio_processing/agc2/down_sampler.h new file mode 100644 index 0000000000..a609ea8e0c --- /dev/null +++ b/modules/audio_processing/agc2/down_sampler.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_DOWN_SAMPLER_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_DOWN_SAMPLER_H_ + +#include "api/array_view.h" +#include "modules/audio_processing/agc2/biquad_filter.h" +#include "rtc_base/constructormagic.h" + +namespace webrtc { + +class ApmDataDumper; + +class DownSampler { + public: + explicit DownSampler(ApmDataDumper* data_dumper); + void Initialize(int sample_rate_hz); + + void DownSample(rtc::ArrayView in, rtc::ArrayView out); + + private: + ApmDataDumper* data_dumper_; + int sample_rate_hz_; + int down_sampling_factor_; + BiQuadFilter low_pass_filter_; + + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(DownSampler); +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_DOWN_SAMPLER_H_ diff --git a/modules/audio_processing/agc2/noise_level_estimator.cc b/modules/audio_processing/agc2/noise_level_estimator.cc index ede431c799..d9aaf1f9bd 100644 --- a/modules/audio_processing/agc2/noise_level_estimator.cc +++ b/modules/audio_processing/agc2/noise_level_estimator.cc @@ -10,11 +10,102 @@ #include "modules/audio_processing/agc2/noise_level_estimator.h" +#include + +#include +#include + +#include "common_audio/include/audio_util.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" + namespace webrtc { -float NoiseLevelEstimator::Analyze(AudioFrameView frame) { - // TODO(webrtc:7494): This is a stub. Add implementation. - return -50.f; +namespace { +constexpr int kFramesPerSecond = 100; + +float FrameEnergy(const AudioFrameView& audio) { + float energy = 0.f; + for (size_t k = 0; k < audio.num_channels(); ++k) { + float channel_energy = + std::accumulate(audio.channel(k).begin(), audio.channel(k).end(), 0.f, + [](float a, float b) -> float { return a + b * b; }); + energy = std::max(channel_energy, energy); + } + return energy; +} + +float EnergyToDbfs(float signal_energy, size_t num_samples) { + const float rms = std::sqrt(signal_energy / num_samples); + return FloatS16ToDbfs(rms); +} +} // namespace + +NoiseLevelEstimator::NoiseLevelEstimator(ApmDataDumper* data_dumper) + : signal_classifier_(data_dumper) { + Initialize(48000); +} + +NoiseLevelEstimator::~NoiseLevelEstimator() {} + +void NoiseLevelEstimator::Initialize(int sample_rate_hz) { + sample_rate_hz_ = sample_rate_hz; + noise_energy_ = 1.f; + first_update_ = true; + min_noise_energy_ = sample_rate_hz * 2.f * 2.f / kFramesPerSecond; + noise_energy_hold_counter_ = 0; + signal_classifier_.Initialize(sample_rate_hz); +} + +float NoiseLevelEstimator::Analyze(const AudioFrameView& frame) { + const int rate = + static_cast(frame.samples_per_channel() * kFramesPerSecond); + if (rate != sample_rate_hz_) { + Initialize(rate); + } + const float frame_energy = FrameEnergy(frame); + if (frame_energy <= 0.f) { + RTC_DCHECK_GE(frame_energy, 0.f); + return EnergyToDbfs(noise_energy_, frame.samples_per_channel()); + } + + if (first_update_) { + // Initialize the noise energy to the frame energy. + first_update_ = false; + return EnergyToDbfs( + noise_energy_ = std::max(frame_energy, min_noise_energy_), + frame.samples_per_channel()); + } + + const SignalClassifier::SignalType signal_type = + signal_classifier_.Analyze(frame.channel(0)); + + // Update the noise estimate in a minimum statistics-type manner. + if (signal_type == SignalClassifier::SignalType::kStationary) { + if (frame_energy > noise_energy_) { + // Leak the estimate upwards towards the frame energy if no recent + // downward update. + noise_energy_hold_counter_ = std::max(noise_energy_hold_counter_ - 1, 0); + + if (noise_energy_hold_counter_ == 0) { + noise_energy_ = std::min(noise_energy_ * 1.01f, frame_energy); + } + } else { + // Update smoothly downwards with a limited maximum update magnitude. + noise_energy_ = + std::max(noise_energy_ * 0.9f, + noise_energy_ + 0.05f * (frame_energy - noise_energy_)); + noise_energy_hold_counter_ = 1000; + } + } else { + // For a non-stationary signal, leak the estimate downwards in order to + // avoid estimate locking due to incorrect signal classification. + noise_energy_ = noise_energy_ * 0.99f; + } + + // Ensure a minimum of the estimate. + return EnergyToDbfs( + noise_energy_ = std::max(noise_energy_, min_noise_energy_), + frame.samples_per_channel()); } } // namespace webrtc diff --git a/modules/audio_processing/agc2/noise_level_estimator.h b/modules/audio_processing/agc2/noise_level_estimator.h index f9e4abc8f5..24067a1665 100644 --- a/modules/audio_processing/agc2/noise_level_estimator.h +++ b/modules/audio_processing/agc2/noise_level_estimator.h @@ -11,19 +11,30 @@ #ifndef MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_ #define MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_ +#include "modules/audio_processing/agc2/signal_classifier.h" #include "modules/audio_processing/include/audio_frame_view.h" #include "rtc_base/constructormagic.h" namespace webrtc { +class ApmDataDumper; class NoiseLevelEstimator { public: - NoiseLevelEstimator() {} - + NoiseLevelEstimator(ApmDataDumper* data_dumper); + ~NoiseLevelEstimator(); // Returns the estimated noise level in dBFS. - float Analyze(AudioFrameView frame); + float Analyze(const AudioFrameView& frame); private: + void Initialize(int sample_rate_hz); + + int sample_rate_hz_; + float min_noise_energy_; + bool first_update_; + float noise_energy_; + int noise_energy_hold_counter_; + SignalClassifier signal_classifier_; + RTC_DISALLOW_COPY_AND_ASSIGN(NoiseLevelEstimator); }; diff --git a/modules/audio_processing/agc2/noise_level_estimator_unittest.cc b/modules/audio_processing/agc2/noise_level_estimator_unittest.cc new file mode 100644 index 0000000000..c4fd33b0a0 --- /dev/null +++ b/modules/audio_processing/agc2/noise_level_estimator_unittest.cc @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/noise_level_estimator.h" + +#include +#include +#include + +#include "modules/audio_processing/agc2/agc2_testing_common.h" +#include "modules/audio_processing/agc2/vector_float_frame.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/gunit.h" +#include "rtc_base/random.h" + +namespace webrtc { +namespace { +Random rand_gen(42); +ApmDataDumper data_dumper(0); +constexpr int kNumIterations = 200; +constexpr int kFramesPerSecond = 100; + +// Runs the noise estimator on audio generated by 'sample_generator' +// for kNumIterations. Returns the last noise level estimate. +float RunEstimator(std::function sample_generator, int rate) { + NoiseLevelEstimator estimator(&data_dumper); + const size_t samples_per_channel = + rtc::CheckedDivExact(rate, kFramesPerSecond); + VectorFloatFrame signal(1, static_cast(samples_per_channel), 0.f); + + for (int i = 0; i < kNumIterations; ++i) { + AudioFrameView frame_view = signal.float_frame_view(); + for (size_t j = 0; j < samples_per_channel; ++j) { + frame_view.channel(0)[j] = sample_generator(); + } + estimator.Analyze(frame_view); + } + return estimator.Analyze(signal.float_frame_view()); +} + +float WhiteNoiseGenerator() { + return static_cast(rand_gen.Rand(std::numeric_limits::min(), + std::numeric_limits::max())); +} +} // namespace + +// White random noise is stationary, but does not trigger the detector +// every frame due to the randomness. +TEST(AutomaticGainController2NoiseEstimator, RandomNoise) { + for (const auto rate : {8000, 16000, 32000, 48000}) { + const float noise_level = RunEstimator(WhiteNoiseGenerator, rate); + EXPECT_NEAR(noise_level, -5.f, 1.f); + } +} + +// Sine curves are (very) stationary. They trigger the detector all +// the time. Except for a few initial frames. +TEST(AutomaticGainController2NoiseEstimator, SineTone) { + for (const auto rate : {8000, 16000, 32000, 48000}) { + test::SineGenerator gen(600.f, rate); + const float noise_level = RunEstimator(gen, rate); + EXPECT_NEAR(noise_level, -33.f, 1.f); + } +} + +// Pulses are transient if they are far enough apart. They shouldn't +// trigger the noise detector. +TEST(AutomaticGainController2NoiseEstimator, PulseTone) { + for (const auto rate : {8000, 16000, 32000, 48000}) { + test::PulseGenerator gen(20.f, rate); + const int noise_level = RunEstimator(gen, rate); + EXPECT_NEAR(noise_level, -79.f, 1.f); + } +} + +} // namespace webrtc diff --git a/modules/audio_processing/agc2/noise_spectrum_estimator.cc b/modules/audio_processing/agc2/noise_spectrum_estimator.cc new file mode 100644 index 0000000000..9e08126e89 --- /dev/null +++ b/modules/audio_processing/agc2/noise_spectrum_estimator.cc @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/noise_spectrum_estimator.h" + +#include +#include + +#include "api/array_view.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/arraysize.h" + +namespace webrtc { +namespace { +constexpr float kMinNoisePower = 100.f; +} // namespace + +NoiseSpectrumEstimator::NoiseSpectrumEstimator(ApmDataDumper* data_dumper) + : data_dumper_(data_dumper) { + Initialize(); +} + +void NoiseSpectrumEstimator::Initialize() { + std::fill(noise_spectrum_, noise_spectrum_ + arraysize(noise_spectrum_), + kMinNoisePower); +} + +void NoiseSpectrumEstimator::Update(rtc::ArrayView spectrum, + bool first_update) { + RTC_DCHECK_EQ(65, spectrum.size()); + + if (first_update) { + // Initialize the noise spectral estimate with the signal spectrum. + std::copy(spectrum.data(), spectrum.data() + spectrum.size(), + noise_spectrum_); + } else { + // Smoothly update the noise spectral estimate towards the signal spectrum + // such that the magnitude of the updates are limited. + for (size_t k = 0; k < spectrum.size(); ++k) { + if (noise_spectrum_[k] < spectrum[k]) { + noise_spectrum_[k] = std::min( + 1.01f * noise_spectrum_[k], + noise_spectrum_[k] + 0.05f * (spectrum[k] - noise_spectrum_[k])); + } else { + noise_spectrum_[k] = std::max( + 0.99f * noise_spectrum_[k], + noise_spectrum_[k] + 0.05f * (spectrum[k] - noise_spectrum_[k])); + } + } + } + + // Ensure that the noise spectal estimate does not become too low. + for (auto& v : noise_spectrum_) { + v = std::max(v, kMinNoisePower); + } + + data_dumper_->DumpRaw("lc_noise_spectrum", 65, noise_spectrum_); + data_dumper_->DumpRaw("lc_signal_spectrum", spectrum); +} + +} // namespace webrtc diff --git a/modules/audio_processing/agc2/noise_spectrum_estimator.h b/modules/audio_processing/agc2/noise_spectrum_estimator.h new file mode 100644 index 0000000000..fd1cc13a3f --- /dev/null +++ b/modules/audio_processing/agc2/noise_spectrum_estimator.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_NOISE_SPECTRUM_ESTIMATOR_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_NOISE_SPECTRUM_ESTIMATOR_H_ + +#include "api/array_view.h" +#include "rtc_base/constructormagic.h" + +namespace webrtc { + +class ApmDataDumper; + +class NoiseSpectrumEstimator { + public: + explicit NoiseSpectrumEstimator(ApmDataDumper* data_dumper); + void Initialize(); + void Update(rtc::ArrayView spectrum, bool first_update); + + rtc::ArrayView GetNoiseSpectrum() const { + return rtc::ArrayView(noise_spectrum_); + } + + private: + ApmDataDumper* data_dumper_; + float noise_spectrum_[65]; + + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(NoiseSpectrumEstimator); +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_NOISE_SPECTRUM_ESTIMATOR_H_ diff --git a/modules/audio_processing/agc2/signal_classifier.cc b/modules/audio_processing/agc2/signal_classifier.cc new file mode 100644 index 0000000000..0ec34148b9 --- /dev/null +++ b/modules/audio_processing/agc2/signal_classifier.cc @@ -0,0 +1,167 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/signal_classifier.h" + +#include +#include +#include + +#include "api/array_view.h" +#include "modules/audio_processing/agc2/down_sampler.h" +#include "modules/audio_processing/agc2/noise_spectrum_estimator.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/constructormagic.h" + +namespace webrtc { +namespace { + +void RemoveDcLevel(rtc::ArrayView x) { + RTC_DCHECK_LT(0, x.size()); + float mean = std::accumulate(x.data(), x.data() + x.size(), 0.f); + mean /= x.size(); + + for (float& v : x) { + v -= mean; + } +} + +void PowerSpectrum(const OouraFft* ooura_fft, + rtc::ArrayView x, + rtc::ArrayView spectrum) { + RTC_DCHECK_EQ(65, spectrum.size()); + RTC_DCHECK_EQ(128, x.size()); + float X[128]; + std::copy(x.data(), x.data() + x.size(), X); + ooura_fft->Fft(X); + + float* X_p = X; + RTC_DCHECK_EQ(X_p, &X[0]); + spectrum[0] = (*X_p) * (*X_p); + ++X_p; + RTC_DCHECK_EQ(X_p, &X[1]); + spectrum[64] = (*X_p) * (*X_p); + for (int k = 1; k < 64; ++k) { + ++X_p; + RTC_DCHECK_EQ(X_p, &X[2 * k]); + spectrum[k] = (*X_p) * (*X_p); + ++X_p; + RTC_DCHECK_EQ(X_p, &X[2 * k + 1]); + spectrum[k] += (*X_p) * (*X_p); + } +} + +webrtc::SignalClassifier::SignalType ClassifySignal( + rtc::ArrayView signal_spectrum, + rtc::ArrayView noise_spectrum, + ApmDataDumper* data_dumper) { + int num_stationary_bands = 0; + int num_highly_nonstationary_bands = 0; + + // Detect stationary and highly nonstationary bands. + for (size_t k = 1; k < 40; k++) { + if (signal_spectrum[k] < 3 * noise_spectrum[k] && + signal_spectrum[k] * 3 > noise_spectrum[k]) { + ++num_stationary_bands; + } else if (signal_spectrum[k] > 9 * noise_spectrum[k]) { + ++num_highly_nonstationary_bands; + } + } + + data_dumper->DumpRaw("lc_num_stationary_bands", 1, &num_stationary_bands); + data_dumper->DumpRaw("lc_num_highly_nonstationary_bands", 1, + &num_highly_nonstationary_bands); + + // Use the detected number of bands to classify the overall signal + // stationarity. + if (num_stationary_bands > 15) { + return SignalClassifier::SignalType::kStationary; + } else { + return SignalClassifier::SignalType::kNonStationary; + } +} + +} // namespace + +SignalClassifier::FrameExtender::FrameExtender(size_t frame_size, + size_t extended_frame_size) + : x_old_(extended_frame_size - frame_size, 0.f) {} + +SignalClassifier::FrameExtender::~FrameExtender() = default; + +void SignalClassifier::FrameExtender::ExtendFrame( + rtc::ArrayView x, + rtc::ArrayView x_extended) { + RTC_DCHECK_EQ(x_old_.size() + x.size(), x_extended.size()); + std::copy(x_old_.data(), x_old_.data() + x_old_.size(), x_extended.data()); + std::copy(x.data(), x.data() + x.size(), x_extended.data() + x_old_.size()); + std::copy(x_extended.data() + x_extended.size() - x_old_.size(), + x_extended.data() + x_extended.size(), x_old_.data()); +} + +SignalClassifier::SignalClassifier(ApmDataDumper* data_dumper) + : data_dumper_(data_dumper), + down_sampler_(data_dumper_), + noise_spectrum_estimator_(data_dumper_) { + Initialize(48000); +} +SignalClassifier::~SignalClassifier() {} + +void SignalClassifier::Initialize(int sample_rate_hz) { + down_sampler_.Initialize(sample_rate_hz); + noise_spectrum_estimator_.Initialize(); + frame_extender_.reset(new FrameExtender(80, 128)); + sample_rate_hz_ = sample_rate_hz; + initialization_frames_left_ = 2; + consistent_classification_counter_ = 3; + last_signal_type_ = SignalClassifier::SignalType::kNonStationary; +} + +SignalClassifier::SignalType SignalClassifier::Analyze( + rtc::ArrayView signal) { + RTC_DCHECK_EQ(signal.size(), sample_rate_hz_ / 100); + + // Compute the signal power spectrum. + float downsampled_frame[80]; + down_sampler_.DownSample(signal, downsampled_frame); + float extended_frame[128]; + frame_extender_->ExtendFrame(downsampled_frame, extended_frame); + RemoveDcLevel(extended_frame); + float signal_spectrum[65]; + PowerSpectrum(&ooura_fft_, extended_frame, signal_spectrum); + + // Classify the signal based on the estimate of the noise spectrum and the + // signal spectrum estimate. + const SignalType signal_type = ClassifySignal( + signal_spectrum, noise_spectrum_estimator_.GetNoiseSpectrum(), + data_dumper_); + + // Update the noise spectrum based on the signal spectrum. + noise_spectrum_estimator_.Update(signal_spectrum, + initialization_frames_left_ > 0); + + // Update the number of frames until a reliable signal spectrum is achieved. + initialization_frames_left_ = std::max(0, initialization_frames_left_ - 1); + + if (last_signal_type_ == signal_type) { + consistent_classification_counter_ = + std::max(0, consistent_classification_counter_ - 1); + } else { + last_signal_type_ = signal_type; + consistent_classification_counter_ = 3; + } + + if (consistent_classification_counter_ > 0) { + return SignalClassifier::SignalType::kNonStationary; + } + return signal_type; +} + +} // namespace webrtc diff --git a/modules/audio_processing/agc2/signal_classifier.h b/modules/audio_processing/agc2/signal_classifier.h new file mode 100644 index 0000000000..23fe315977 --- /dev/null +++ b/modules/audio_processing/agc2/signal_classifier.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_SIGNAL_CLASSIFIER_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_SIGNAL_CLASSIFIER_H_ + +#include +#include + +#include "api/array_view.h" +#include "modules/audio_processing/agc2/down_sampler.h" +#include "modules/audio_processing/agc2/noise_spectrum_estimator.h" +#include "modules/audio_processing/utility/ooura_fft.h" +#include "rtc_base/constructormagic.h" + +namespace webrtc { + +class ApmDataDumper; +class AudioBuffer; + +class SignalClassifier { + public: + enum class SignalType { kNonStationary, kStationary }; + + explicit SignalClassifier(ApmDataDumper* data_dumper); + ~SignalClassifier(); + + void Initialize(int sample_rate_hz); + SignalType Analyze(rtc::ArrayView signal); + + private: + class FrameExtender { + public: + FrameExtender(size_t frame_size, size_t extended_frame_size); + ~FrameExtender(); + + void ExtendFrame(rtc::ArrayView x, + rtc::ArrayView x_extended); + + private: + std::vector x_old_; + + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(FrameExtender); + }; + + ApmDataDumper* const data_dumper_; + DownSampler down_sampler_; + std::unique_ptr frame_extender_; + NoiseSpectrumEstimator noise_spectrum_estimator_; + int sample_rate_hz_; + int initialization_frames_left_; + int consistent_classification_counter_; + SignalType last_signal_type_; + const OouraFft ooura_fft_; + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(SignalClassifier); +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_SIGNAL_CLASSIFIER_H_ diff --git a/modules/audio_processing/agc2/signal_classifier_unittest.cc b/modules/audio_processing/agc2/signal_classifier_unittest.cc new file mode 100644 index 0000000000..62171b32e6 --- /dev/null +++ b/modules/audio_processing/agc2/signal_classifier_unittest.cc @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/signal_classifier.h" + +#include +#include +#include + +#include "modules/audio_processing/agc2/agc2_testing_common.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/gunit.h" +#include "rtc_base/random.h" + +namespace webrtc { + +namespace { +Random rand_gen(42); +ApmDataDumper data_dumper(0); +constexpr int kNumIterations = 100; + +// Runs the signal classifier on audio generated by 'sample_generator' +// for kNumIterations. Returns the number of frames classified as noise. +int RunClassifier(std::function sample_generator, int rate) { + SignalClassifier classifier(&data_dumper); + std::array signal; + classifier.Initialize(rate); + const size_t samples_per_channel = rtc::CheckedDivExact(rate, 100); + int number_of_noise_frames = 0; + for (int i = 0; i < kNumIterations; ++i) { + for (size_t j = 0; j < samples_per_channel; ++j) { + signal[j] = sample_generator(); + } + number_of_noise_frames += + classifier.Analyze({&signal[0], samples_per_channel}) == + SignalClassifier::SignalType::kStationary; + } + return number_of_noise_frames; +} + +float WhiteNoiseGenerator() { + return static_cast(rand_gen.Rand(std::numeric_limits::min(), + std::numeric_limits::max())); +} +} // namespace + +// White random noise is stationary, but does not trigger the detector +// every frame due to the randomness. +TEST(AutomaticGainController2SignalClassifier, WhiteNoise) { + for (const auto rate : {8000, 16000, 32000, 48000}) { + const int number_of_noise_frames = RunClassifier(WhiteNoiseGenerator, rate); + EXPECT_GT(number_of_noise_frames, kNumIterations / 2); + } +} + +// Sine curves are (very) stationary. They trigger the detector all +// the time. Except for a few initial frames. +TEST(AutomaticGainController2SignalClassifier, SineTone) { + for (const auto rate : {8000, 16000, 32000, 48000}) { + test::SineGenerator gen(600.f, rate); + const int number_of_noise_frames = RunClassifier(gen, rate); + EXPECT_GE(number_of_noise_frames, kNumIterations - 5); + } +} + +// Pulses are transient if they are far enough apart. They shouldn't +// trigger the noise detector. +TEST(AutomaticGainController2SignalClassifier, PulseTone) { + for (const auto rate : {8000, 16000, 32000, 48000}) { + test::PulseGenerator gen(30.f, rate); + const int number_of_noise_frames = RunClassifier(gen, rate); + EXPECT_EQ(number_of_noise_frames, 0); + } +} +} // namespace webrtc