diff --git a/api/audio/echo_canceller3_config.h b/api/audio/echo_canceller3_config.h index ed39b66420..98bd463f8c 100644 --- a/api/audio/echo_canceller3_config.h +++ b/api/audio/echo_canceller3_config.h @@ -157,6 +157,7 @@ struct EchoCanceller3Config { struct Suppressor { size_t bands_with_reliable_coherence = 5; + size_t nearend_average_blocks = 4; } suppressor; }; } // namespace webrtc diff --git a/modules/audio_processing/aec3/BUILD.gn b/modules/audio_processing/aec3/BUILD.gn index c0f632aac4..237f71aa54 100644 --- a/modules/audio_processing/aec3/BUILD.gn +++ b/modules/audio_processing/aec3/BUILD.gn @@ -69,6 +69,8 @@ rtc_static_library("aec3") { "matched_filter_lag_aggregator.h", "matrix_buffer.cc", "matrix_buffer.h", + "moving_average.cc", + "moving_average.h", "render_buffer.cc", "render_buffer.h", "render_delay_buffer.cc", diff --git a/modules/audio_processing/aec3/echo_canceller3.cc b/modules/audio_processing/aec3/echo_canceller3.cc index 3acb31b259..0caf179627 100644 --- a/modules/audio_processing/aec3/echo_canceller3.cc +++ b/modules/audio_processing/aec3/echo_canceller3.cc @@ -41,6 +41,11 @@ bool EnableReverbModelling() { return !field_trial::IsEnabled("WebRTC-Aec3ReverbModellingKillSwitch"); } +bool EnableSuppressorNearendAveraging() { + return !field_trial::IsEnabled( + "WebRTC-Aec3SuppressorNearendAveragingKillSwitch"); +} + // Method for adjusting config parameter dependencies.. EchoCanceller3Config AdjustConfig(const EchoCanceller3Config& config) { EchoCanceller3Config adjusted_cfg = config; @@ -89,6 +94,10 @@ EchoCanceller3Config AdjustConfig(const EchoCanceller3Config& config) { adjusted_cfg.ep_strength.reverb_based_on_render = false; } + if (!EnableSuppressorNearendAveraging()) { + adjusted_cfg.suppressor.nearend_average_blocks = 1; + } + return adjusted_cfg; } diff --git a/modules/audio_processing/aec3/moving_average.cc b/modules/audio_processing/aec3/moving_average.cc new file mode 100644 index 0000000000..e9d64e6b32 --- /dev/null +++ b/modules/audio_processing/aec3/moving_average.cc @@ -0,0 +1,58 @@ + +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/moving_average.h" + +#include +#include + +namespace webrtc { +namespace aec3 { + +MovingAverage::MovingAverage(size_t num_elem, size_t mem_len) + : num_elem_(num_elem), + mem_len_(mem_len - 1), + scaling_(1.0f / static_cast(mem_len)), + memory_(num_elem * mem_len_, 0.f), + mem_index_(0) { + RTC_DCHECK(num_elem_ > 0); + RTC_DCHECK(mem_len > 0); +} + +MovingAverage::~MovingAverage() = default; + +void MovingAverage::Average(rtc::ArrayView input, + rtc::ArrayView output) { + RTC_DCHECK(input.size() == num_elem_); + RTC_DCHECK(output.size() == num_elem_); + + // Sum all contributions. + std::copy(input.begin(), input.end(), output.begin()); + for (auto i = memory_.begin(); i < memory_.end(); i += num_elem_) { + std::transform(i, i + num_elem_, output.begin(), output.begin(), + std::plus()); + } + + // Divide by mem_len_. + for (float& o : output) { + o *= scaling_; + } + + // Update memory. + if (mem_len_ > 0) { + std::copy(input.begin(), input.end(), + memory_.begin() + mem_index_ * num_elem_); + mem_index_ = (mem_index_ + 1) % mem_len_; + } +} + +} // namespace aec3 +} // namespace webrtc diff --git a/modules/audio_processing/aec3/moving_average.h b/modules/audio_processing/aec3/moving_average.h new file mode 100644 index 0000000000..94497d782c --- /dev/null +++ b/modules/audio_processing/aec3/moving_average.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_MOVING_AVERAGE_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_MOVING_AVERAGE_H_ + +#include + +#include "api/array_view.h" + +namespace webrtc { +namespace aec3 { + +class MovingAverage { + public: + // Creates an instance of MovingAverage that accepts inputs of length num_elem + // and averages over mem_len inputs. + MovingAverage(size_t num_elem, size_t mem_len); + ~MovingAverage(); + + // Computes the average of input and mem_len-1 previous inputs and stores the + // result in output. + void Average(rtc::ArrayView input, rtc::ArrayView output); + + private: + const size_t num_elem_; + const size_t mem_len_; + const float scaling_; + std::vector memory_; + size_t mem_index_; +}; + +} // namespace aec3 +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_MOVING_AVERAGE_H_ diff --git a/modules/audio_processing/aec3/suppression_gain.cc b/modules/audio_processing/aec3/suppression_gain.cc index 2098fc25f8..6b86b877d3 100644 --- a/modules/audio_processing/aec3/suppression_gain.cc +++ b/modules/audio_processing/aec3/suppression_gain.cc @@ -20,6 +20,7 @@ #include #include +#include "modules/audio_processing/aec3/moving_average.h" #include "modules/audio_processing/aec3/vector_math.h" #include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/atomicops.h" @@ -386,7 +387,9 @@ SuppressionGain::SuppressionGain(const EchoCanceller3Config& config, static_cast(config_.filter.config_change_duration_blocks)), coherence_gain_(sample_rate_hz, config_.suppressor.bands_with_reliable_coherence), - enable_transparency_improvements_(EnableTransparencyImprovements()) { + enable_transparency_improvements_(EnableTransparencyImprovements()), + moving_average_(kFftLengthBy2Plus1, + config.suppressor.nearend_average_blocks) { RTC_DCHECK_LT(0, state_change_duration_blocks_); one_by_state_change_duration_blocks_ = 1.f / state_change_duration_blocks_; last_gain_.fill(1.f); @@ -413,11 +416,14 @@ void SuppressionGain::GetGain( RTC_DCHECK(high_bands_gain); RTC_DCHECK(low_band_gain); + std::array nearend_average; + moving_average_.Average(nearend_spectrum, nearend_average); + // Compute gain for the lower band. bool low_noise_render = low_render_detector_.Detect(render); const absl::optional narrow_peak_band = render_signal_analyzer.NarrowPeakBand(); - LowerBandGain(low_noise_render, aec_state, nearend_spectrum, echo_spectrum, + LowerBandGain(low_noise_render, aec_state, nearend_average, echo_spectrum, comfort_noise_spectrum, low_band_gain); // Adjust the gain for bands where the coherence indicates not echo. diff --git a/modules/audio_processing/aec3/suppression_gain.h b/modules/audio_processing/aec3/suppression_gain.h index f3719eed4e..ced3666fb5 100644 --- a/modules/audio_processing/aec3/suppression_gain.h +++ b/modules/audio_processing/aec3/suppression_gain.h @@ -18,6 +18,7 @@ #include "modules/audio_processing/aec3/aec3_common.h" #include "modules/audio_processing/aec3/aec_state.h" #include "modules/audio_processing/aec3/coherence_gain.h" +#include "modules/audio_processing/aec3/moving_average.h" #include "modules/audio_processing/aec3/render_signal_analyzer.h" #include "rtc_base/constructormagic.h" @@ -85,6 +86,7 @@ class SuppressionGain { int initial_state_change_counter_ = 0; CoherenceGain coherence_gain_; const bool enable_transparency_improvements_; + aec3::MovingAverage moving_average_; RTC_DISALLOW_COPY_AND_ASSIGN(SuppressionGain); };