diff --git a/modules/audio_processing/BUILD.gn b/modules/audio_processing/BUILD.gn index b36e76b1cc..b0b425329d 100644 --- a/modules/audio_processing/BUILD.gn +++ b/modules/audio_processing/BUILD.gn @@ -550,6 +550,7 @@ if (rtc_include_tests) { "aec_dump:mock_aec_dump_unittests", "agc2:adaptive_digital_unittests", "agc2:fixed_digital_unittests", + "agc2:noise_estimator_unittests", "test/conversational_speech:unittest", "vad:vad_unittests", "//testing/gtest", diff --git a/modules/audio_processing/agc2/BUILD.gn b/modules/audio_processing/agc2/BUILD.gn index ca71b9278f..f9051257e4 100644 --- a/modules/audio_processing/agc2/BUILD.gn +++ b/modules/audio_processing/agc2/BUILD.gn @@ -23,8 +23,6 @@ rtc_source_set("adaptive_digital") { "adaptive_digital_gain_applier.h", "adaptive_mode_level_estimator.cc", "adaptive_mode_level_estimator.h", - "noise_level_estimator.cc", - "noise_level_estimator.h", "saturation_protector.cc", "saturation_protector.h", ] @@ -33,6 +31,7 @@ rtc_source_set("adaptive_digital") { deps = [ ":common", + ":noise_level_estimator", "..:aec_core", "..:apm_logging", "..:audio_frame_view", @@ -83,6 +82,32 @@ rtc_source_set("common") { ] } +rtc_source_set("noise_level_estimator") { + sources = [ + "biquad_filter.cc", + "biquad_filter.h", + "down_sampler.cc", + "down_sampler.h", + "noise_level_estimator.cc", + "noise_level_estimator.h", + "noise_spectrum_estimator.cc", + "noise_spectrum_estimator.h", + "signal_classifier.cc", + "signal_classifier.h", + ] + deps = [ + "..:aec_core", + "..:apm_logging", + "..:audio_frame_view", + "../../../api:array_view", + "../../../common_audio", + "../../../rtc_base:checks", + "../../../rtc_base:macromagic", + ] + + configs += [ "..:apm_debug_dump" ] +} + rtc_source_set("test_utils") { testonly = true visibility = [ ":*" ] @@ -151,3 +176,23 @@ rtc_source_set("adaptive_digital_unittests") { "../vad:vad_with_level", ] } + +rtc_source_set("noise_estimator_unittests") { + testonly = true + configs += [ "..:apm_debug_dump" ] + + 
sources = [ + "noise_level_estimator_unittest.cc", + "signal_classifier_unittest.cc", + ] + deps = [ + ":noise_level_estimator", + ":test_utils", + "..:apm_logging", + "..:audio_frame_view", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../../../rtc_base:rtc_base_tests_utils", + ] +} diff --git a/modules/audio_processing/agc2/adaptive_agc.cc b/modules/audio_processing/agc2/adaptive_agc.cc index dff38fb44b..0de27a41a6 100644 --- a/modules/audio_processing/agc2/adaptive_agc.cc +++ b/modules/audio_processing/agc2/adaptive_agc.cc @@ -22,7 +22,8 @@ namespace webrtc { AdaptiveAgc::AdaptiveAgc(ApmDataDumper* apm_data_dumper) : speech_level_estimator_(apm_data_dumper), gain_applier_(apm_data_dumper), - apm_data_dumper_(apm_data_dumper) { + apm_data_dumper_(apm_data_dumper), + noise_level_estimator_(apm_data_dumper) { RTC_DCHECK(apm_data_dumper); } diff --git a/modules/audio_processing/agc2/agc2_testing_common.h b/modules/audio_processing/agc2/agc2_testing_common.h index a176282ede..8c4f400a98 100644 --- a/modules/audio_processing/agc2/agc2_testing_common.h +++ b/modules/audio_processing/agc2/agc2_testing_common.h @@ -11,9 +11,13 @@ #ifndef MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_ #define MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_ +#include + +#include #include #include "rtc_base/basictypes.h" +#include "rtc_base/checks.h" namespace webrtc { @@ -26,8 +30,49 @@ constexpr float kDecayMs = 500.f; constexpr float kLimiterMaxInputLevelDbFs = 1.f; constexpr float kLimiterKneeSmoothnessDb = 1.f; constexpr float kLimiterCompressionRatio = 5.f; +constexpr float kPi = 3.1415926536f; std::vector LinSpace(const double l, const double r, size_t num_points); + +class SineGenerator { + public: + SineGenerator(float frequency, int rate) + : frequency_(frequency), rate_(rate) {} + float operator()() { + x_radians_ += frequency_ / rate_ * 2 * kPi; + if (x_radians_ > 2 * kPi) { + x_radians_ -= 2 * kPi; + } + return 
1000.f * sinf(x_radians_); + } + + private: + float frequency_; + int rate_; + float x_radians_ = 0.f; +}; + +class PulseGenerator { + public: + PulseGenerator(float frequency, int rate) + : samples_period_( + static_cast(static_cast(rate) / frequency)) { + RTC_DCHECK_GT(rate, frequency); + } + float operator()() { + sample_counter_++; + if (sample_counter_ >= samples_period_) { + sample_counter_ -= samples_period_; + } + return static_cast( + sample_counter_ == 0 ? std::numeric_limits::max() : 10.f); + } + + private: + int samples_period_; + int sample_counter_ = 0; +}; + } // namespace test } // namespace webrtc diff --git a/modules/audio_processing/agc2/biquad_filter.cc b/modules/audio_processing/agc2/biquad_filter.cc new file mode 100644 index 0000000000..c15c6449ad --- /dev/null +++ b/modules/audio_processing/agc2/biquad_filter.cc @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/biquad_filter.h" + +namespace webrtc { + +// This method applies a biquad filter to an input signal x to produce an +// output signal y. The biquad coefficients are specified at the construction +// of the object. +void BiQuadFilter::Process(rtc::ArrayView x, + rtc::ArrayView y) { + for (size_t k = 0; k < x.size(); ++k) { + // Use temporary variable for x[k] to allow in-place function call + // (that x and y refer to the same array). 
+ const float tmp = x[k]; + y[k] = coefficients_.b[0] * tmp + coefficients_.b[1] * biquad_state_.b[0] + + coefficients_.b[2] * biquad_state_.b[1] - + coefficients_.a[0] * biquad_state_.a[0] - + coefficients_.a[1] * biquad_state_.a[1]; + biquad_state_.b[1] = biquad_state_.b[0]; + biquad_state_.b[0] = tmp; + biquad_state_.a[1] = biquad_state_.a[0]; + biquad_state_.a[0] = y[k]; + } +} + +} // namespace webrtc diff --git a/modules/audio_processing/agc2/biquad_filter.h b/modules/audio_processing/agc2/biquad_filter.h new file mode 100644 index 0000000000..4fd5e2e392 --- /dev/null +++ b/modules/audio_processing/agc2/biquad_filter.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_ + +#include "api/array_view.h" +#include "rtc_base/arraysize.h" +#include "rtc_base/constructormagic.h" + +namespace webrtc { + +class BiQuadFilter { + public: + struct BiQuadCoefficients { + float b[3]; + float a[2]; + }; + + BiQuadFilter() = default; + + void Initialize(const BiQuadCoefficients& coefficients) { + coefficients_ = coefficients; + } + + // Produces a filtered output y of the input x. Both x and y need to + // have the same length. 
+ void Process(rtc::ArrayView x, rtc::ArrayView y); + + private: + struct BiQuadState { + BiQuadState() { + std::fill(b, b + arraysize(b), 0.f); + std::fill(a, a + arraysize(a), 0.f); + } + + float b[2]; + float a[2]; + }; + + BiQuadState biquad_state_; + BiQuadCoefficients coefficients_; + + RTC_DISALLOW_COPY_AND_ASSIGN(BiQuadFilter); +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_ diff --git a/modules/audio_processing/agc2/down_sampler.cc b/modules/audio_processing/agc2/down_sampler.cc new file mode 100644 index 0000000000..50486e0a36 --- /dev/null +++ b/modules/audio_processing/agc2/down_sampler.cc @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/down_sampler.h" + +#include +#include + +#include "modules/audio_processing/agc2/biquad_filter.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/checks.h" + +namespace webrtc { +namespace { + +constexpr int kChunkSizeMs = 10; +constexpr int kSampleRate8kHz = 8000; +constexpr int kSampleRate16kHz = 16000; +constexpr int kSampleRate32kHz = 32000; +constexpr int kSampleRate48kHz = 48000; + +// Bandlimiter coefficients computed based on that only +// the first 40 bins of the spectrum for the downsampled +// signal are used. 
+// [B,A] = butter(2,(41/64*4000)/8000) +const BiQuadFilter::BiQuadCoefficients kLowPassFilterCoefficients_16kHz = { + {0.1455f, 0.2911f, 0.1455f}, + {-0.6698f, 0.2520f}}; + +// [B,A] = butter(2,(41/64*4000)/16000) +const BiQuadFilter::BiQuadCoefficients kLowPassFilterCoefficients_32kHz = { + {0.0462f, 0.0924f, 0.0462f}, + {-1.3066f, 0.4915f}}; + +// [B,A] = butter(2,(41/64*4000)/24000) +const BiQuadFilter::BiQuadCoefficients kLowPassFilterCoefficients_48kHz = { + {0.0226f, 0.0452f, 0.0226f}, + {-1.5320f, 0.6224f}}; + +} // namespace + +DownSampler::DownSampler(ApmDataDumper* data_dumper) + : data_dumper_(data_dumper) { + Initialize(48000); +} +void DownSampler::Initialize(int sample_rate_hz) { + RTC_DCHECK( + sample_rate_hz == kSampleRate8kHz || sample_rate_hz == kSampleRate16kHz || + sample_rate_hz == kSampleRate32kHz || sample_rate_hz == kSampleRate48kHz); + + sample_rate_hz_ = sample_rate_hz; + down_sampling_factor_ = rtc::CheckedDivExact(sample_rate_hz_, 8000); + + /// Note that the down sampling filter is not used if the sample rate is 8 + /// kHz. + if (sample_rate_hz_ == kSampleRate16kHz) { + low_pass_filter_.Initialize(kLowPassFilterCoefficients_16kHz); + } else if (sample_rate_hz_ == kSampleRate32kHz) { + low_pass_filter_.Initialize(kLowPassFilterCoefficients_32kHz); + } else if (sample_rate_hz_ == kSampleRate48kHz) { + low_pass_filter_.Initialize(kLowPassFilterCoefficients_48kHz); + } +} + +void DownSampler::DownSample(rtc::ArrayView in, + rtc::ArrayView out) { + data_dumper_->DumpWav("lc_down_sampler_input", in, sample_rate_hz_, 1); + RTC_DCHECK_EQ(sample_rate_hz_ * kChunkSizeMs / 1000, in.size()); + RTC_DCHECK_EQ(kSampleRate8kHz * kChunkSizeMs / 1000, out.size()); + const size_t kMaxNumFrames = kSampleRate48kHz * kChunkSizeMs / 1000; + float x[kMaxNumFrames]; + + // Band-limit the signal to 4 kHz. + if (sample_rate_hz_ != kSampleRate8kHz) { + low_pass_filter_.Process(in, rtc::ArrayView(x, in.size())); + + // Downsample the signal. 
+ size_t k = 0; + for (size_t j = 0; j < out.size(); ++j) { + RTC_DCHECK_GT(kMaxNumFrames, k); + out[j] = x[k]; + k += down_sampling_factor_; + } + } else { + std::copy(in.data(), in.data() + in.size(), out.data()); + } + + data_dumper_->DumpWav("lc_down_sampler_output", out, kSampleRate8kHz, 1); +} + +} // namespace webrtc diff --git a/modules/audio_processing/agc2/down_sampler.h b/modules/audio_processing/agc2/down_sampler.h new file mode 100644 index 0000000000..a609ea8e0c --- /dev/null +++ b/modules/audio_processing/agc2/down_sampler.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_DOWN_SAMPLER_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_DOWN_SAMPLER_H_ + +#include "api/array_view.h" +#include "modules/audio_processing/agc2/biquad_filter.h" +#include "rtc_base/constructormagic.h" + +namespace webrtc { + +class ApmDataDumper; + +class DownSampler { + public: + explicit DownSampler(ApmDataDumper* data_dumper); + void Initialize(int sample_rate_hz); + + void DownSample(rtc::ArrayView in, rtc::ArrayView out); + + private: + ApmDataDumper* data_dumper_; + int sample_rate_hz_; + int down_sampling_factor_; + BiQuadFilter low_pass_filter_; + + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(DownSampler); +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_DOWN_SAMPLER_H_ diff --git a/modules/audio_processing/agc2/noise_level_estimator.cc b/modules/audio_processing/agc2/noise_level_estimator.cc index ede431c799..d9aaf1f9bd 100644 --- a/modules/audio_processing/agc2/noise_level_estimator.cc +++ 
b/modules/audio_processing/agc2/noise_level_estimator.cc @@ -10,11 +10,102 @@ #include "modules/audio_processing/agc2/noise_level_estimator.h" +#include + +#include +#include + +#include "common_audio/include/audio_util.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" + namespace webrtc { -float NoiseLevelEstimator::Analyze(AudioFrameView frame) { - // TODO(webrtc:7494): This is a stub. Add implementation. - return -50.f; +namespace { +constexpr int kFramesPerSecond = 100; + +float FrameEnergy(const AudioFrameView& audio) { + float energy = 0.f; + for (size_t k = 0; k < audio.num_channels(); ++k) { + float channel_energy = + std::accumulate(audio.channel(k).begin(), audio.channel(k).end(), 0.f, + [](float a, float b) -> float { return a + b * b; }); + energy = std::max(channel_energy, energy); + } + return energy; +} + +float EnergyToDbfs(float signal_energy, size_t num_samples) { + const float rms = std::sqrt(signal_energy / num_samples); + return FloatS16ToDbfs(rms); +} +} // namespace + +NoiseLevelEstimator::NoiseLevelEstimator(ApmDataDumper* data_dumper) + : signal_classifier_(data_dumper) { + Initialize(48000); +} + +NoiseLevelEstimator::~NoiseLevelEstimator() {} + +void NoiseLevelEstimator::Initialize(int sample_rate_hz) { + sample_rate_hz_ = sample_rate_hz; + noise_energy_ = 1.f; + first_update_ = true; + min_noise_energy_ = sample_rate_hz * 2.f * 2.f / kFramesPerSecond; + noise_energy_hold_counter_ = 0; + signal_classifier_.Initialize(sample_rate_hz); +} + +float NoiseLevelEstimator::Analyze(const AudioFrameView& frame) { + const int rate = + static_cast(frame.samples_per_channel() * kFramesPerSecond); + if (rate != sample_rate_hz_) { + Initialize(rate); + } + const float frame_energy = FrameEnergy(frame); + if (frame_energy <= 0.f) { + RTC_DCHECK_GE(frame_energy, 0.f); + return EnergyToDbfs(noise_energy_, frame.samples_per_channel()); + } + + if (first_update_) { + // Initialize the noise energy to the frame energy. 
+ first_update_ = false; + return EnergyToDbfs( + noise_energy_ = std::max(frame_energy, min_noise_energy_), + frame.samples_per_channel()); + } + + const SignalClassifier::SignalType signal_type = + signal_classifier_.Analyze(frame.channel(0)); + + // Update the noise estimate in a minimum statistics-type manner. + if (signal_type == SignalClassifier::SignalType::kStationary) { + if (frame_energy > noise_energy_) { + // Leak the estimate upwards towards the frame energy if no recent + // downward update. + noise_energy_hold_counter_ = std::max(noise_energy_hold_counter_ - 1, 0); + + if (noise_energy_hold_counter_ == 0) { + noise_energy_ = std::min(noise_energy_ * 1.01f, frame_energy); + } + } else { + // Update smoothly downwards with a limited maximum update magnitude. + noise_energy_ = + std::max(noise_energy_ * 0.9f, + noise_energy_ + 0.05f * (frame_energy - noise_energy_)); + noise_energy_hold_counter_ = 1000; + } + } else { + // For a non-stationary signal, leak the estimate downwards in order to + // avoid estimate locking due to incorrect signal classification. + noise_energy_ = noise_energy_ * 0.99f; + } + + // Ensure a minimum of the estimate. 
+ return EnergyToDbfs( + noise_energy_ = std::max(noise_energy_, min_noise_energy_), + frame.samples_per_channel()); } } // namespace webrtc diff --git a/modules/audio_processing/agc2/noise_level_estimator.h b/modules/audio_processing/agc2/noise_level_estimator.h index f9e4abc8f5..24067a1665 100644 --- a/modules/audio_processing/agc2/noise_level_estimator.h +++ b/modules/audio_processing/agc2/noise_level_estimator.h @@ -11,19 +11,30 @@ #ifndef MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_ #define MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_ +#include "modules/audio_processing/agc2/signal_classifier.h" #include "modules/audio_processing/include/audio_frame_view.h" #include "rtc_base/constructormagic.h" namespace webrtc { +class ApmDataDumper; class NoiseLevelEstimator { public: - NoiseLevelEstimator() {} - + explicit NoiseLevelEstimator(ApmDataDumper* data_dumper); + ~NoiseLevelEstimator(); // Returns the estimated noise level in dBFS. - float Analyze(AudioFrameView frame); + float Analyze(const AudioFrameView& frame); private: + void Initialize(int sample_rate_hz); + + int sample_rate_hz_; + float min_noise_energy_; + bool first_update_; + float noise_energy_; + int noise_energy_hold_counter_; + SignalClassifier signal_classifier_; + RTC_DISALLOW_COPY_AND_ASSIGN(NoiseLevelEstimator); }; diff --git a/modules/audio_processing/agc2/noise_level_estimator_unittest.cc b/modules/audio_processing/agc2/noise_level_estimator_unittest.cc new file mode 100644 index 0000000000..c4fd33b0a0 --- /dev/null +++ b/modules/audio_processing/agc2/noise_level_estimator_unittest.cc @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. 
All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/noise_level_estimator.h" + +#include +#include +#include + +#include "modules/audio_processing/agc2/agc2_testing_common.h" +#include "modules/audio_processing/agc2/vector_float_frame.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/gunit.h" +#include "rtc_base/random.h" + +namespace webrtc { +namespace { +Random rand_gen(42); +ApmDataDumper data_dumper(0); +constexpr int kNumIterations = 200; +constexpr int kFramesPerSecond = 100; + +// Runs the noise estimator on audio generated by 'sample_generator' +// for kNumIterations. Returns the last noise level estimate. +float RunEstimator(std::function sample_generator, int rate) { + NoiseLevelEstimator estimator(&data_dumper); + const size_t samples_per_channel = + rtc::CheckedDivExact(rate, kFramesPerSecond); + VectorFloatFrame signal(1, static_cast(samples_per_channel), 0.f); + + for (int i = 0; i < kNumIterations; ++i) { + AudioFrameView frame_view = signal.float_frame_view(); + for (size_t j = 0; j < samples_per_channel; ++j) { + frame_view.channel(0)[j] = sample_generator(); + } + estimator.Analyze(frame_view); + } + return estimator.Analyze(signal.float_frame_view()); +} + +float WhiteNoiseGenerator() { + return static_cast(rand_gen.Rand(std::numeric_limits::min(), + std::numeric_limits::max())); +} +} // namespace + +// White random noise is stationary, but does not trigger the detector +// every frame due to the randomness. +TEST(AutomaticGainController2NoiseEstimator, RandomNoise) { + for (const auto rate : {8000, 16000, 32000, 48000}) { + const float noise_level = RunEstimator(WhiteNoiseGenerator, rate); + EXPECT_NEAR(noise_level, -5.f, 1.f); + } +} + +// Sine curves are (very) stationary. They trigger the detector all +// the time. Except for a few initial frames. 
+TEST(AutomaticGainController2NoiseEstimator, SineTone) { + for (const auto rate : {8000, 16000, 32000, 48000}) { + test::SineGenerator gen(600.f, rate); + const float noise_level = RunEstimator(gen, rate); + EXPECT_NEAR(noise_level, -33.f, 1.f); + } +} + +// Pulses are transient if they are far enough apart. They shouldn't +// trigger the noise detector. +TEST(AutomaticGainController2NoiseEstimator, PulseTone) { + for (const auto rate : {8000, 16000, 32000, 48000}) { + test::PulseGenerator gen(20.f, rate); + const float noise_level = RunEstimator(gen, rate); + EXPECT_NEAR(noise_level, -79.f, 1.f); + } +} + +} // namespace webrtc diff --git a/modules/audio_processing/agc2/noise_spectrum_estimator.cc b/modules/audio_processing/agc2/noise_spectrum_estimator.cc new file mode 100644 index 0000000000..9e08126e89 --- /dev/null +++ b/modules/audio_processing/agc2/noise_spectrum_estimator.cc @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "modules/audio_processing/agc2/noise_spectrum_estimator.h" + +#include +#include + +#include "api/array_view.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/arraysize.h" + +namespace webrtc { +namespace { +constexpr float kMinNoisePower = 100.f; +} // namespace + +NoiseSpectrumEstimator::NoiseSpectrumEstimator(ApmDataDumper* data_dumper) + : data_dumper_(data_dumper) { + Initialize(); +} + +void NoiseSpectrumEstimator::Initialize() { + std::fill(noise_spectrum_, noise_spectrum_ + arraysize(noise_spectrum_), + kMinNoisePower); +} + +void NoiseSpectrumEstimator::Update(rtc::ArrayView spectrum, + bool first_update) { + RTC_DCHECK_EQ(65, spectrum.size()); + + if (first_update) { + // Initialize the noise spectral estimate with the signal spectrum. + std::copy(spectrum.data(), spectrum.data() + spectrum.size(), + noise_spectrum_); + } else { + // Smoothly update the noise spectral estimate towards the signal spectrum + // such that the magnitude of the updates is limited. + for (size_t k = 0; k < spectrum.size(); ++k) { + if (noise_spectrum_[k] < spectrum[k]) { + noise_spectrum_[k] = std::min( + 1.01f * noise_spectrum_[k], + noise_spectrum_[k] + 0.05f * (spectrum[k] - noise_spectrum_[k])); + } else { + noise_spectrum_[k] = std::max( + 0.99f * noise_spectrum_[k], + noise_spectrum_[k] + 0.05f * (spectrum[k] - noise_spectrum_[k])); + } + } + } + + // Ensure that the noise spectral estimate does not become too low. 
+ for (auto& v : noise_spectrum_) { + v = std::max(v, kMinNoisePower); + } + + data_dumper_->DumpRaw("lc_noise_spectrum", 65, noise_spectrum_); + data_dumper_->DumpRaw("lc_signal_spectrum", spectrum); +} + +} // namespace webrtc diff --git a/modules/audio_processing/agc2/noise_spectrum_estimator.h b/modules/audio_processing/agc2/noise_spectrum_estimator.h new file mode 100644 index 0000000000..fd1cc13a3f --- /dev/null +++ b/modules/audio_processing/agc2/noise_spectrum_estimator.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_NOISE_SPECTRUM_ESTIMATOR_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_NOISE_SPECTRUM_ESTIMATOR_H_ + +#include "api/array_view.h" +#include "rtc_base/constructormagic.h" + +namespace webrtc { + +class ApmDataDumper; + +class NoiseSpectrumEstimator { + public: + explicit NoiseSpectrumEstimator(ApmDataDumper* data_dumper); + void Initialize(); + void Update(rtc::ArrayView spectrum, bool first_update); + + rtc::ArrayView GetNoiseSpectrum() const { + return rtc::ArrayView(noise_spectrum_); + } + + private: + ApmDataDumper* data_dumper_; + float noise_spectrum_[65]; + + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(NoiseSpectrumEstimator); +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_NOISE_SPECTRUM_ESTIMATOR_H_ diff --git a/modules/audio_processing/agc2/signal_classifier.cc b/modules/audio_processing/agc2/signal_classifier.cc new file mode 100644 index 0000000000..0ec34148b9 --- /dev/null +++ b/modules/audio_processing/agc2/signal_classifier.cc @@ -0,0 +1,167 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. 
All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/signal_classifier.h" + +#include +#include +#include + +#include "api/array_view.h" +#include "modules/audio_processing/agc2/down_sampler.h" +#include "modules/audio_processing/agc2/noise_spectrum_estimator.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/constructormagic.h" + +namespace webrtc { +namespace { + +void RemoveDcLevel(rtc::ArrayView x) { + RTC_DCHECK_LT(0, x.size()); + float mean = std::accumulate(x.data(), x.data() + x.size(), 0.f); + mean /= x.size(); + + for (float& v : x) { + v -= mean; + } +} + +void PowerSpectrum(const OouraFft* ooura_fft, + rtc::ArrayView x, + rtc::ArrayView spectrum) { + RTC_DCHECK_EQ(65, spectrum.size()); + RTC_DCHECK_EQ(128, x.size()); + float X[128]; + std::copy(x.data(), x.data() + x.size(), X); + ooura_fft->Fft(X); + + float* X_p = X; + RTC_DCHECK_EQ(X_p, &X[0]); + spectrum[0] = (*X_p) * (*X_p); + ++X_p; + RTC_DCHECK_EQ(X_p, &X[1]); + spectrum[64] = (*X_p) * (*X_p); + for (int k = 1; k < 64; ++k) { + ++X_p; + RTC_DCHECK_EQ(X_p, &X[2 * k]); + spectrum[k] = (*X_p) * (*X_p); + ++X_p; + RTC_DCHECK_EQ(X_p, &X[2 * k + 1]); + spectrum[k] += (*X_p) * (*X_p); + } +} + +webrtc::SignalClassifier::SignalType ClassifySignal( + rtc::ArrayView signal_spectrum, + rtc::ArrayView noise_spectrum, + ApmDataDumper* data_dumper) { + int num_stationary_bands = 0; + int num_highly_nonstationary_bands = 0; + + // Detect stationary and highly nonstationary bands. 
+ for (size_t k = 1; k < 40; k++) { + if (signal_spectrum[k] < 3 * noise_spectrum[k] && + signal_spectrum[k] * 3 > noise_spectrum[k]) { + ++num_stationary_bands; + } else if (signal_spectrum[k] > 9 * noise_spectrum[k]) { + ++num_highly_nonstationary_bands; + } + } + + data_dumper->DumpRaw("lc_num_stationary_bands", 1, &num_stationary_bands); + data_dumper->DumpRaw("lc_num_highly_nonstationary_bands", 1, + &num_highly_nonstationary_bands); + + // Use the detected number of bands to classify the overall signal + // stationarity. + if (num_stationary_bands > 15) { + return SignalClassifier::SignalType::kStationary; + } else { + return SignalClassifier::SignalType::kNonStationary; + } +} + +} // namespace + +SignalClassifier::FrameExtender::FrameExtender(size_t frame_size, + size_t extended_frame_size) + : x_old_(extended_frame_size - frame_size, 0.f) {} + +SignalClassifier::FrameExtender::~FrameExtender() = default; + +void SignalClassifier::FrameExtender::ExtendFrame( + rtc::ArrayView x, + rtc::ArrayView x_extended) { + RTC_DCHECK_EQ(x_old_.size() + x.size(), x_extended.size()); + std::copy(x_old_.data(), x_old_.data() + x_old_.size(), x_extended.data()); + std::copy(x.data(), x.data() + x.size(), x_extended.data() + x_old_.size()); + std::copy(x_extended.data() + x_extended.size() - x_old_.size(), + x_extended.data() + x_extended.size(), x_old_.data()); +} + +SignalClassifier::SignalClassifier(ApmDataDumper* data_dumper) + : data_dumper_(data_dumper), + down_sampler_(data_dumper_), + noise_spectrum_estimator_(data_dumper_) { + Initialize(48000); +} +SignalClassifier::~SignalClassifier() {} + +void SignalClassifier::Initialize(int sample_rate_hz) { + down_sampler_.Initialize(sample_rate_hz); + noise_spectrum_estimator_.Initialize(); + frame_extender_.reset(new FrameExtender(80, 128)); + sample_rate_hz_ = sample_rate_hz; + initialization_frames_left_ = 2; + consistent_classification_counter_ = 3; + last_signal_type_ = SignalClassifier::SignalType::kNonStationary; +} + 
+SignalClassifier::SignalType SignalClassifier::Analyze( + rtc::ArrayView signal) { + RTC_DCHECK_EQ(signal.size(), sample_rate_hz_ / 100); + + // Compute the signal power spectrum. + float downsampled_frame[80]; + down_sampler_.DownSample(signal, downsampled_frame); + float extended_frame[128]; + frame_extender_->ExtendFrame(downsampled_frame, extended_frame); + RemoveDcLevel(extended_frame); + float signal_spectrum[65]; + PowerSpectrum(&ooura_fft_, extended_frame, signal_spectrum); + + // Classify the signal based on the estimate of the noise spectrum and the + // signal spectrum estimate. + const SignalType signal_type = ClassifySignal( + signal_spectrum, noise_spectrum_estimator_.GetNoiseSpectrum(), + data_dumper_); + + // Update the noise spectrum based on the signal spectrum. + noise_spectrum_estimator_.Update(signal_spectrum, + initialization_frames_left_ > 0); + + // Update the number of frames until a reliable signal spectrum is achieved. + initialization_frames_left_ = std::max(0, initialization_frames_left_ - 1); + + if (last_signal_type_ == signal_type) { + consistent_classification_counter_ = + std::max(0, consistent_classification_counter_ - 1); + } else { + last_signal_type_ = signal_type; + consistent_classification_counter_ = 3; + } + + if (consistent_classification_counter_ > 0) { + return SignalClassifier::SignalType::kNonStationary; + } + return signal_type; +} + +} // namespace webrtc diff --git a/modules/audio_processing/agc2/signal_classifier.h b/modules/audio_processing/agc2/signal_classifier.h new file mode 100644 index 0000000000..23fe315977 --- /dev/null +++ b/modules/audio_processing/agc2/signal_classifier.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. 
All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_SIGNAL_CLASSIFIER_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_SIGNAL_CLASSIFIER_H_ + +#include +#include + +#include "api/array_view.h" +#include "modules/audio_processing/agc2/down_sampler.h" +#include "modules/audio_processing/agc2/noise_spectrum_estimator.h" +#include "modules/audio_processing/utility/ooura_fft.h" +#include "rtc_base/constructormagic.h" + +namespace webrtc { + +class ApmDataDumper; +class AudioBuffer; + +class SignalClassifier { + public: + enum class SignalType { kNonStationary, kStationary }; + + explicit SignalClassifier(ApmDataDumper* data_dumper); + ~SignalClassifier(); + + void Initialize(int sample_rate_hz); + SignalType Analyze(rtc::ArrayView signal); + + private: + class FrameExtender { + public: + FrameExtender(size_t frame_size, size_t extended_frame_size); + ~FrameExtender(); + + void ExtendFrame(rtc::ArrayView x, + rtc::ArrayView x_extended); + + private: + std::vector x_old_; + + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(FrameExtender); + }; + + ApmDataDumper* const data_dumper_; + DownSampler down_sampler_; + std::unique_ptr frame_extender_; + NoiseSpectrumEstimator noise_spectrum_estimator_; + int sample_rate_hz_; + int initialization_frames_left_; + int consistent_classification_counter_; + SignalType last_signal_type_; + const OouraFft ooura_fft_; + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(SignalClassifier); +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_SIGNAL_CLASSIFIER_H_ diff --git a/modules/audio_processing/agc2/signal_classifier_unittest.cc b/modules/audio_processing/agc2/signal_classifier_unittest.cc new file mode 100644 index 0000000000..62171b32e6 --- /dev/null +++ b/modules/audio_processing/agc2/signal_classifier_unittest.cc @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. 
+ * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/signal_classifier.h" + +#include +#include +#include + +#include "modules/audio_processing/agc2/agc2_testing_common.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/gunit.h" +#include "rtc_base/random.h" + +namespace webrtc { + +namespace { +Random rand_gen(42); +ApmDataDumper data_dumper(0); +constexpr int kNumIterations = 100; + +// Runs the signal classifier on audio generated by 'sample_generator' +// for kNumIterations. Returns the number of frames classified as noise. +int RunClassifier(std::function sample_generator, int rate) { + SignalClassifier classifier(&data_dumper); + std::array signal; + classifier.Initialize(rate); + const size_t samples_per_channel = rtc::CheckedDivExact(rate, 100); + int number_of_noise_frames = 0; + for (int i = 0; i < kNumIterations; ++i) { + for (size_t j = 0; j < samples_per_channel; ++j) { + signal[j] = sample_generator(); + } + number_of_noise_frames += + classifier.Analyze({&signal[0], samples_per_channel}) == + SignalClassifier::SignalType::kStationary; + } + return number_of_noise_frames; +} + +float WhiteNoiseGenerator() { + return static_cast(rand_gen.Rand(std::numeric_limits::min(), + std::numeric_limits::max())); +} +} // namespace + +// White random noise is stationary, but does not trigger the detector +// every frame due to the randomness. 
+TEST(AutomaticGainController2SignalClassifier, WhiteNoise) { + for (const auto rate : {8000, 16000, 32000, 48000}) { + const int number_of_noise_frames = RunClassifier(WhiteNoiseGenerator, rate); + EXPECT_GT(number_of_noise_frames, kNumIterations / 2); + } +} + +// Sine curves are (very) stationary. They trigger the detector all +// the time. Except for a few initial frames. +TEST(AutomaticGainController2SignalClassifier, SineTone) { + for (const auto rate : {8000, 16000, 32000, 48000}) { + test::SineGenerator gen(600.f, rate); + const int number_of_noise_frames = RunClassifier(gen, rate); + EXPECT_GE(number_of_noise_frames, kNumIterations - 5); + } +} + +// Pulses are transient if they are far enough apart. They shouldn't +// trigger the noise detector. +TEST(AutomaticGainController2SignalClassifier, PulseTone) { + for (const auto rate : {8000, 16000, 32000, 48000}) { + test::PulseGenerator gen(30.f, rate); + const int number_of_noise_frames = RunClassifier(gen, rate); + EXPECT_EQ(number_of_noise_frames, 0); + } +} +} // namespace webrtc