diff --git a/api/audio/echo_canceller3_config.h b/api/audio/echo_canceller3_config.h index 1acb26fa89..fd5bf0963e 100644 --- a/api/audio/echo_canceller3_config.h +++ b/api/audio/echo_canceller3_config.h @@ -51,6 +51,8 @@ struct EchoCanceller3Config { MainConfiguration main_initial = {12, 0.05f, 5.f, 0.001f, 20075344.f}; ShadowConfiguration shadow_initial = {12, 0.9f, 20075344.f}; + + size_t config_change_duration_blocks = 250; } filter; struct Erle { diff --git a/modules/audio_processing/aec3/adaptive_fir_filter.cc b/modules/audio_processing/aec3/adaptive_fir_filter.cc index b3d8ca56bd..9bea40bbf1 100644 --- a/modules/audio_processing/aec3/adaptive_fir_filter.cc +++ b/modules/audio_processing/aec3/adaptive_fir_filter.cc @@ -416,16 +416,27 @@ void ApplyFilter_SSE2(const RenderBuffer& render_buffer, } // namespace aec3 AdaptiveFirFilter::AdaptiveFirFilter(size_t max_size_partitions, + size_t initial_size_partitions, + size_t size_change_duration_blocks, Aec3Optimization optimization, ApmDataDumper* data_dumper) : data_dumper_(data_dumper), fft_(), optimization_(optimization), max_size_partitions_(max_size_partitions), + size_change_duration_blocks_( + static_cast(size_change_duration_blocks)), + current_size_partitions_(initial_size_partitions), + target_size_partitions_(initial_size_partitions), + old_target_size_partitions_(initial_size_partitions), H_(max_size_partitions_), H2_(max_size_partitions_, std::array()), h_(GetTimeDomainLength(max_size_partitions_), 0.f) { RTC_DCHECK(data_dumper_); + RTC_DCHECK_GE(max_size_partitions, initial_size_partitions); + + RTC_DCHECK_LT(0, size_change_duration_blocks_); + one_by_size_change_duration_blocks_ = 1.f / size_change_duration_blocks_; for (auto& H_j : H_) { H_j.Clear(); @@ -434,6 +445,7 @@ AdaptiveFirFilter::AdaptiveFirFilter(size_t max_size_partitions, H2_k.fill(0.f); } erl_.fill(0.f); + SetSizePartitions(current_size_partitions_, true); } AdaptiveFirFilter::~AdaptiveFirFilter() = default; @@ -460,7 +472,7 @@ void AdaptiveFirFilter::HandleEchoPathChange() { erl_.fill(0.f); } -void AdaptiveFirFilter::SetSizePartitions(size_t size) { +void AdaptiveFirFilter::SetSizePartitions(size_t size, bool immediate_effect) { RTC_DCHECK_EQ(max_size_partitions_, H_.capacity()); RTC_DCHECK_EQ(max_size_partitions_, H2_.capacity()); RTC_DCHECK_EQ(GetTimeDomainLength(max_size_partitions_), h_.capacity()); @@ -468,22 +480,53 @@ void AdaptiveFirFilter::SetSizePartitions(size_t size) { RTC_DCHECK_EQ(h_.size(), GetTimeDomainLength(H_.size())); RTC_DCHECK_LE(size, max_size_partitions_); - if (size > max_size_partitions_) { - size = max_size_partitions_; + target_size_partitions_ = std::min(max_size_partitions_, size); + if (immediate_effect) { + current_size_partitions_ = old_target_size_partitions_ = + target_size_partitions_; + ResetFilterBuffersToCurrentSize(); + size_change_counter_ = 0; + } else { + size_change_counter_ = size_change_duration_blocks_; } +} - if (size < H_.size()) { - for (size_t k = size; k < H_.size(); ++k) { +void AdaptiveFirFilter::ResetFilterBuffersToCurrentSize() { + if (current_size_partitions_ < H_.size()) { + for (size_t k = current_size_partitions_; k < H_.size(); ++k) { H_[k].Clear(); H2_[k].fill(0.f); } - - std::fill(h_.begin() + GetTimeDomainLength(size), h_.end(), 0.f); + std::fill(h_.begin() + GetTimeDomainLength(current_size_partitions_), + h_.end(), 0.f); } - H_.resize(size); - H2_.resize(size); - h_.resize(GetTimeDomainLength(size)); + H_.resize(current_size_partitions_); + H2_.resize(current_size_partitions_); + h_.resize(GetTimeDomainLength(current_size_partitions_)); +} + +void AdaptiveFirFilter::UpdateSize() { + RTC_DCHECK_GE(size_change_duration_blocks_, size_change_counter_); + if (size_change_counter_ > 0) { + --size_change_counter_; + + auto average = [](float from, float to, float from_weight) { + return from * from_weight + to * (1.f - from_weight); + }; + + float change_factor = + size_change_counter_ * one_by_size_change_duration_blocks_; + + current_size_partitions_ = average(old_target_size_partitions_, + target_size_partitions_, change_factor); + + ResetFilterBuffersToCurrentSize(); + } else { + current_size_partitions_ = old_target_size_partitions_ = + target_size_partitions_; + } + RTC_DCHECK_LE(0, size_change_counter_); } void AdaptiveFirFilter::Filter(const RenderBuffer& render_buffer, @@ -507,6 +550,9 @@ void AdaptiveFirFilter::Filter(const RenderBuffer& render_buffer, void AdaptiveFirFilter::Adapt(const RenderBuffer& render_buffer, const FftData& G) { + // Update the filter size if needed. + UpdateSize(); + // Adapt the filter. switch (optimization_) { #if defined(WEBRTC_ARCH_X86_FAMILY) diff --git a/modules/audio_processing/aec3/adaptive_fir_filter.h b/modules/audio_processing/aec3/adaptive_fir_filter.h index 0a9828c119..1e128b57f6 100644 --- a/modules/audio_processing/aec3/adaptive_fir_filter.h +++ b/modules/audio_processing/aec3/adaptive_fir_filter.h @@ -92,6 +92,8 @@ void ApplyFilter_SSE2(const RenderBuffer& render_buffer, class AdaptiveFirFilter { public: AdaptiveFirFilter(size_t max_size_partitions, + size_t initial_size_partitions, + size_t size_change_duration_blocks, Aec3Optimization optimization, ApmDataDumper* data_dumper); @@ -111,7 +113,7 @@ class AdaptiveFirFilter { size_t SizePartitions() const { return H_.size(); } // Sets the filter size. - void SetSizePartitions(size_t size); + void SetSizePartitions(size_t size, bool immediate_effect); // Returns the filter based echo return loss. const std::array& Erl() const { return erl_; } @@ -145,10 +147,22 @@ class AdaptiveFirFilter { // Constrain the filter partitions in a cyclic manner. void Constrain(); + // Resets the filter buffers to use the current size. + void ResetFilterBuffersToCurrentSize(); + + // Gradually Updates the current filter size towards the target size. + void UpdateSize(); + ApmDataDumper* const data_dumper_; const Aec3Fft fft_; const Aec3Optimization optimization_; const size_t max_size_partitions_; + const int size_change_duration_blocks_; + float one_by_size_change_duration_blocks_; + size_t current_size_partitions_; + size_t target_size_partitions_; + size_t old_target_size_partitions_; + int size_change_counter_ = 0; std::vector H_; std::vector> H2_; std::vector h_; diff --git a/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc b/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc index ae283ed4fb..9fb11cd508 100644 --- a/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc +++ b/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc @@ -268,13 +268,13 @@ TEST(AdaptiveFirFilter, UpdateErlSse2Optimization) { #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) // Verifies that the check for non-null data dumper works. TEST(AdaptiveFirFilter, NullDataDumper) { - EXPECT_DEATH(AdaptiveFirFilter(9, DetectOptimization(), nullptr), ""); + EXPECT_DEATH(AdaptiveFirFilter(9, 9, 250, DetectOptimization(), nullptr), ""); } // Verifies that the check for non-null filter output works. TEST(AdaptiveFirFilter, NullFilterOutput) { ApmDataDumper data_dumper(42); - AdaptiveFirFilter filter(9, DetectOptimization(), &data_dumper); + AdaptiveFirFilter filter(9, 9, 250, DetectOptimization(), &data_dumper); std::unique_ptr render_delay_buffer( RenderDelayBuffer::Create(EchoCanceller3Config(), 3)); EXPECT_DEATH(filter.Filter(*render_delay_buffer->GetRenderBuffer(), nullptr), @@ -287,7 +287,7 @@ TEST(AdaptiveFirFilter, NullFilterOutput) { // are turned on. TEST(AdaptiveFirFilter, FilterStatisticsAccess) { ApmDataDumper data_dumper(42); - AdaptiveFirFilter filter(9, DetectOptimization(), &data_dumper); + AdaptiveFirFilter filter(9, 9, 250, DetectOptimization(), &data_dumper); filter.Erl(); filter.FilterFrequencyResponse(); } @@ -296,7 +296,8 @@ TEST(AdaptiveFirFilter, FilterStatisticsAccess) { TEST(AdaptiveFirFilter, FilterSize) { ApmDataDumper data_dumper(42); for (size_t filter_size = 1; filter_size < 5; ++filter_size) { - AdaptiveFirFilter filter(filter_size, DetectOptimization(), &data_dumper); + AdaptiveFirFilter filter(filter_size, filter_size, 250, + DetectOptimization(), &data_dumper); EXPECT_EQ(filter_size, filter.SizePartitions()); } } @@ -308,13 +309,16 @@ TEST(AdaptiveFirFilter, FilterAndAdapt) { ApmDataDumper data_dumper(42); EchoCanceller3Config config; AdaptiveFirFilter filter(config.filter.main.length_blocks, + config.filter.main.length_blocks, + config.filter.config_change_duration_blocks, DetectOptimization(), &data_dumper); Aec3Fft fft; config.delay.min_echo_path_delay_blocks = 0; config.delay.default_delay = 1; std::unique_ptr render_delay_buffer( RenderDelayBuffer::Create(config, 3)); - ShadowFilterUpdateGain gain(config.filter.shadow); + ShadowFilterUpdateGain gain(config.filter.shadow, + config.filter.config_change_duration_blocks); Random random_generator(42U); std::vector> x(3, std::vector(kBlockSize, 0.f)); std::vector n(kBlockSize, 0.f); diff --git a/modules/audio_processing/aec3/echo_remover.cc b/modules/audio_processing/aec3/echo_remover.cc index b8a8ac0fc3..da1fa4be49 100644 --- a/modules/audio_processing/aec3/echo_remover.cc +++ b/modules/audio_processing/aec3/echo_remover.cc @@ -150,6 +150,7 @@ void EchoRemoverImpl::ProcessCapture( if (echo_path_variability.AudioPathChanged()) { subtractor_.HandleEchoPathChange(echo_path_variability); aec_state_.HandleEchoPathChange(echo_path_variability); + suppression_gain_.SetInitialState(true); initial_state_ = true; } @@ -173,6 +174,7 @@ void EchoRemoverImpl::ProcessCapture( // Perform linear echo cancellation. if (initial_state_ && !aec_state_.InitialState()) { subtractor_.ExitInitialState(); + suppression_gain_.SetInitialState(false); initial_state_ = false; } subtractor_.Process(*render_buffer, y0, render_signal_analyzer_, aec_state_, diff --git a/modules/audio_processing/aec3/main_filter_update_gain.cc b/modules/audio_processing/aec3/main_filter_update_gain.cc index 7f7e1b389c..6aa57802a9 100644 --- a/modules/audio_processing/aec3/main_filter_update_gain.cc +++ b/modules/audio_processing/aec3/main_filter_update_gain.cc @@ -29,12 +29,17 @@ constexpr int kPoorExcitationCounterInitial = 1000; int MainFilterUpdateGain::instance_count_ = 0; MainFilterUpdateGain::MainFilterUpdateGain( - const EchoCanceller3Config::Filter::MainConfiguration& config) + const EchoCanceller3Config::Filter::MainConfiguration& config, + size_t config_change_duration_blocks) : data_dumper_( new ApmDataDumper(rtc::AtomicOps::Increment(&instance_count_))), - config_(config), + config_change_duration_blocks_( + static_cast(config_change_duration_blocks)), poor_excitation_counter_(kPoorExcitationCounterInitial) { + SetConfig(config, true); H_error_.fill(kHErrorInitial); + RTC_DCHECK_LT(0, config_change_duration_blocks_); + one_by_config_change_duration_blocks_ = 1.f / config_change_duration_blocks_; } MainFilterUpdateGain::~MainFilterUpdateGain() {} @@ -63,9 +68,10 @@ void MainFilterUpdateGain::Compute( const size_t size_partitions = filter.SizePartitions(); auto X2 = render_power; const auto& erl = filter.Erl(); - ++call_counter_; + UpdateCurrentConfig(); + if (render_signal_analyzer.PoorSignalExcitation()) { poor_excitation_counter_ = 0; } @@ -80,7 +86,7 @@ void MainFilterUpdateGain::Compute( std::array mu; // mu = H_error / (0.5* H_error* X2 + n * E2). for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { - mu[k] = X2[k] > config_.noise_gate + mu[k] = X2[k] > current_config_.noise_gate ? H_error_[k] / (0.5f * H_error_[k] * X2[k] + size_partitions * E2_main[k]) : 0.f; @@ -105,17 +111,47 @@ void MainFilterUpdateGain::Compute( std::array H_error_increase; std::transform(E2_shadow.begin(), E2_shadow.end(), E2_main.begin(), H_error_increase.begin(), [&](float a, float b) { - return a >= b ? config_.leakage_converged - : config_.leakage_diverged; + return a >= b ? current_config_.leakage_converged + : current_config_.leakage_diverged; }); std::transform(erl.begin(), erl.end(), H_error_increase.begin(), H_error_increase.begin(), std::multiplies()); std::transform(H_error_.begin(), H_error_.end(), H_error_increase.begin(), H_error_.begin(), [&](float a, float b) { - return std::max(a + b, config_.error_floor); + return std::max(a + b, current_config_.error_floor); }); data_dumper_->DumpRaw("aec3_main_gain_H_error", H_error_); } +void MainFilterUpdateGain::UpdateCurrentConfig() { + RTC_DCHECK_GE(config_change_duration_blocks_, config_change_counter_); + if (config_change_counter_ > 0) { + if (--config_change_counter_ > 0) { + auto average = [](float from, float to, float from_weight) { + return from * from_weight + to * (1.f - from_weight); + }; + + float change_factor = + config_change_counter_ * one_by_config_change_duration_blocks_; + + current_config_.leakage_converged = + average(old_target_config_.leakage_converged, + target_config_.leakage_converged, change_factor); + current_config_.leakage_diverged = + average(old_target_config_.leakage_diverged, + target_config_.leakage_diverged, change_factor); + current_config_.error_floor = + average(old_target_config_.error_floor, target_config_.error_floor, + change_factor); + current_config_.noise_gate = + average(old_target_config_.noise_gate, target_config_.noise_gate, + change_factor); + } else { + current_config_ = old_target_config_ = target_config_; + } + } + RTC_DCHECK_LE(0, config_change_counter_); +} + } // namespace webrtc diff --git a/modules/audio_processing/aec3/main_filter_update_gain.h b/modules/audio_processing/aec3/main_filter_update_gain.h index 7843a37b3b..525b52279c 100644 --- a/modules/audio_processing/aec3/main_filter_update_gain.h +++ b/modules/audio_processing/aec3/main_filter_update_gain.h @@ -1,6 +1,6 @@ /* * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. -spect * + * * Use of this source code is governed by a BSD-style license * that can be found in the LICENSE file in the root of the source * tree. An additional intellectual property rights grant can be found @@ -26,11 +26,12 @@ namespace webrtc { class ApmDataDumper; -// Provides functionality for computing the adaptive gain for the main filter. +// Provides functionality for computing the adaptive gain for the main filter. class MainFilterUpdateGain { public: explicit MainFilterUpdateGain( - const EchoCanceller3Config::Filter::MainConfiguration& config); + const EchoCanceller3Config::Filter::MainConfiguration& config, + size_t config_change_duration_blocks); ~MainFilterUpdateGain(); // Takes action in the case of a known echo path change. @@ -45,18 +46,34 @@ class MainFilterUpdateGain { FftData* gain_fft); // Sets a new config. - void SetConfig( - const EchoCanceller3Config::Filter::MainConfiguration& config) { - config_ = config; + void SetConfig(const EchoCanceller3Config::Filter::MainConfiguration& config, + bool immediate_effect) { + if (immediate_effect) { + old_target_config_ = current_config_ = target_config_ = config; + config_change_counter_ = 0; + } else { + old_target_config_ = current_config_; + target_config_ = config; + config_change_counter_ = config_change_duration_blocks_; + } } private: static int instance_count_; std::unique_ptr data_dumper_; - EchoCanceller3Config::Filter::MainConfiguration config_; + const int config_change_duration_blocks_; + float one_by_config_change_duration_blocks_; + EchoCanceller3Config::Filter::MainConfiguration current_config_; + EchoCanceller3Config::Filter::MainConfiguration target_config_; + EchoCanceller3Config::Filter::MainConfiguration old_target_config_; std::array H_error_; size_t poor_excitation_counter_; size_t call_counter_ = 0; + int config_change_counter_ = 0; + + // Updates the current config towards the target config. + void UpdateCurrentConfig(); + RTC_DISALLOW_COPY_AND_ASSIGN(MainFilterUpdateGain); }; diff --git a/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc b/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc index bfcf9f87ec..13747d42ac 100644 --- a/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc +++ b/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc @@ -45,14 +45,20 @@ void RunFilterUpdateTest(int num_blocks_to_process, config.filter.main.length_blocks = filter_length_blocks; config.filter.shadow.length_blocks = filter_length_blocks; AdaptiveFirFilter main_filter(config.filter.main.length_blocks, + config.filter.main.length_blocks, + config.filter.config_change_duration_blocks, DetectOptimization(), &data_dumper); AdaptiveFirFilter shadow_filter(config.filter.shadow.length_blocks, + config.filter.shadow.length_blocks, + config.filter.config_change_duration_blocks, DetectOptimization(), &data_dumper); Aec3Fft fft; std::array x_old; x_old.fill(0.f); - ShadowFilterUpdateGain shadow_gain(config.filter.shadow); - MainFilterUpdateGain main_gain(config.filter.main); + ShadowFilterUpdateGain shadow_gain( + config.filter.shadow, config.filter.config_change_duration_blocks); + MainFilterUpdateGain main_gain(config.filter.main, + config.filter.config_change_duration_blocks); Random random_generator(42U); std::vector> x(3, std::vector(kBlockSize, 0.f)); std::vector y(kBlockSize, 0.f); @@ -189,10 +195,13 @@ TEST(MainFilterUpdateGain, NullDataOutputGain) { ApmDataDumper data_dumper(42); EchoCanceller3Config config; AdaptiveFirFilter filter(config.filter.main.length_blocks, + config.filter.main.length_blocks, + config.filter.config_change_duration_blocks, DetectOptimization(), &data_dumper); RenderSignalAnalyzer analyzer(EchoCanceller3Config{}); SubtractorOutput output; - MainFilterUpdateGain gain(config.filter.main); + MainFilterUpdateGain gain(config.filter.main, + config.filter.config_change_duration_blocks); std::array render_power; render_power.fill(0.f); EXPECT_DEATH( diff --git a/modules/audio_processing/aec3/shadow_filter_update_gain.cc b/modules/audio_processing/aec3/shadow_filter_update_gain.cc index 0fc940a555..e27437aff2 100644 --- a/modules/audio_processing/aec3/shadow_filter_update_gain.cc +++ b/modules/audio_processing/aec3/shadow_filter_update_gain.cc @@ -18,8 +18,14 @@ namespace webrtc { ShadowFilterUpdateGain::ShadowFilterUpdateGain( - const EchoCanceller3Config::Filter::ShadowConfiguration& config) - : config_(config) {} + const EchoCanceller3Config::Filter::ShadowConfiguration& config, + size_t config_change_duration_blocks) + : config_change_duration_blocks_( + static_cast(config_change_duration_blocks)) { + SetConfig(config, true); + RTC_DCHECK_LT(0, config_change_duration_blocks_); + one_by_config_change_duration_blocks_ = 1.f / config_change_duration_blocks_; +} void ShadowFilterUpdateGain::HandleEchoPathChange() { // TODO(peah): Check whether this counter should instead be initialized to a @@ -38,6 +44,8 @@ void ShadowFilterUpdateGain::Compute( RTC_DCHECK(G); ++call_counter_; + UpdateCurrentConfig(); + if (render_signal_analyzer.PoorSignalExcitation()) { poor_signal_excitation_counter_ = 0; } @@ -54,7 +62,7 @@ void ShadowFilterUpdateGain::Compute( std::array mu; auto X2 = render_power; std::transform(X2.begin(), X2.end(), mu.begin(), [&](float a) { - return a > config_.noise_gate ? config_.rate / a : 0.f; + return a > current_config_.noise_gate ? current_config_.rate / a : 0.f; }); // Avoid updating the filter close to narrow bands in the render signals. @@ -67,4 +75,27 @@ void ShadowFilterUpdateGain::Compute( std::multiplies()); } +void ShadowFilterUpdateGain::UpdateCurrentConfig() { + RTC_DCHECK_GE(config_change_duration_blocks_, config_change_counter_); + if (config_change_counter_ > 0) { + if (--config_change_counter_ > 0) { + auto average = [](float from, float to, float from_weight) { + return from * from_weight + to * (1.f - from_weight); + }; + + float change_factor = + config_change_counter_ * one_by_config_change_duration_blocks_; + + current_config_.rate = + average(old_target_config_.rate, target_config_.rate, change_factor); + current_config_.noise_gate = + average(old_target_config_.noise_gate, target_config_.noise_gate, + change_factor); + } else { + current_config_ = old_target_config_ = target_config_; + } + } + RTC_DCHECK_LE(0, config_change_counter_); +} + } // namespace webrtc diff --git a/modules/audio_processing/aec3/shadow_filter_update_gain.h b/modules/audio_processing/aec3/shadow_filter_update_gain.h index ce17e4faaf..a92bc3b8b7 100644 --- a/modules/audio_processing/aec3/shadow_filter_update_gain.h +++ b/modules/audio_processing/aec3/shadow_filter_update_gain.h @@ -23,7 +23,8 @@ namespace webrtc { class ShadowFilterUpdateGain { public: explicit ShadowFilterUpdateGain( - const EchoCanceller3Config::Filter::ShadowConfiguration& config); + const EchoCanceller3Config::Filter::ShadowConfiguration& config, + size_t config_change_duration_blocks); // Takes action in the case of a known echo path change. void HandleEchoPathChange(); @@ -38,16 +39,31 @@ class ShadowFilterUpdateGain { // Sets a new config. void SetConfig( - const EchoCanceller3Config::Filter::ShadowConfiguration& config) { - config_ = config; + const EchoCanceller3Config::Filter::ShadowConfiguration& config, + bool immediate_effect) { + if (immediate_effect) { + old_target_config_ = current_config_ = target_config_ = config; + config_change_counter_ = 0; + } else { + old_target_config_ = current_config_; + target_config_ = config; + config_change_counter_ = config_change_duration_blocks_; + } } private: - EchoCanceller3Config::Filter::ShadowConfiguration config_; + EchoCanceller3Config::Filter::ShadowConfiguration current_config_; + EchoCanceller3Config::Filter::ShadowConfiguration target_config_; + EchoCanceller3Config::Filter::ShadowConfiguration old_target_config_; + const int config_change_duration_blocks_; + float one_by_config_change_duration_blocks_; // TODO(peah): Check whether this counter should instead be initialized to a // large value. size_t poor_signal_excitation_counter_ = 0; size_t call_counter_ = 0; + int config_change_counter_ = 0; + + void UpdateCurrentConfig(); }; } // namespace webrtc diff --git a/modules/audio_processing/aec3/shadow_filter_update_gain_unittest.cc b/modules/audio_processing/aec3/shadow_filter_update_gain_unittest.cc index 09e155b699..d77da33f10 100644 --- a/modules/audio_processing/aec3/shadow_filter_update_gain_unittest.cc +++ b/modules/audio_processing/aec3/shadow_filter_update_gain_unittest.cc @@ -40,8 +40,12 @@ void RunFilterUpdateTest(int num_blocks_to_process, EchoCanceller3Config config; config.filter.main.length_blocks = filter_length_blocks; AdaptiveFirFilter main_filter(config.filter.main.length_blocks, + config.filter.main.length_blocks, + config.filter.config_change_duration_blocks, DetectOptimization(), &data_dumper); AdaptiveFirFilter shadow_filter(config.filter.shadow.length_blocks, + config.filter.shadow.length_blocks, + config.filter.config_change_duration_blocks, DetectOptimization(), &data_dumper); Aec3Fft fft; @@ -52,7 +56,8 @@ void RunFilterUpdateTest(int num_blocks_to_process, std::array x_old; x_old.fill(0.f); - ShadowFilterUpdateGain shadow_gain(config.filter.shadow); + ShadowFilterUpdateGain shadow_gain( + config.filter.shadow, config.filter.config_change_duration_blocks); Random random_generator(42U); std::vector> x(3, std::vector(kBlockSize, 0.f)); std::vector y(kBlockSize, 0.f); @@ -134,7 +139,7 @@ TEST(ShadowFilterUpdateGain, NullDataOutputGain) { FftData E; const EchoCanceller3Config::Filter::ShadowConfiguration& config = { 12, 0.5f, 220075344.f}; - ShadowFilterUpdateGain gain(config); + ShadowFilterUpdateGain gain(config, 250); std::array render_power; render_power.fill(0.f); EXPECT_DEATH(gain.Compute(render_power, analyzer, E, 1, false, nullptr), ""); diff --git a/modules/audio_processing/aec3/subtractor.cc b/modules/audio_processing/aec3/subtractor.cc index 7a4e3ce977..b6a68affe3 100644 --- a/modules/audio_processing/aec3/subtractor.cc +++ b/modules/audio_processing/aec3/subtractor.cc @@ -62,13 +62,19 @@ Subtractor::Subtractor(const EchoCanceller3Config& config, optimization_(optimization), config_(config), main_filter_(config_.filter.main.length_blocks, + config_.filter.main_initial.length_blocks, + config.filter.config_change_duration_blocks, optimization, data_dumper_), shadow_filter_(config_.filter.shadow.length_blocks, + config_.filter.shadow_initial.length_blocks, + config.filter.config_change_duration_blocks, optimization, data_dumper_), - G_main_(config_.filter.main_initial), - G_shadow_(config_.filter.shadow_initial) { + G_main_(config_.filter.main_initial, + config_.filter.config_change_duration_blocks), + G_shadow_(config_.filter.shadow_initial, + config.filter.config_change_duration_blocks) { RTC_DCHECK(data_dumper_); // Currently, the rest of AEC3 requires the main and shadow filter lengths to // be identical. @@ -76,14 +82,6 @@ Subtractor::Subtractor(const EchoCanceller3Config& config, config_.filter.shadow.length_blocks); RTC_DCHECK_EQ(config_.filter.main_initial.length_blocks, config_.filter.shadow_initial.length_blocks); - - RTC_DCHECK_GE(config_.filter.main.length_blocks, - config_.filter.main_initial.length_blocks); - RTC_DCHECK_GE(config_.filter.shadow.length_blocks, - config_.filter.shadow_initial.length_blocks); - - main_filter_.SetSizePartitions(config_.filter.main_initial.length_blocks); - shadow_filter_.SetSizePartitions(config_.filter.shadow_initial.length_blocks); } Subtractor::~Subtractor() = default; @@ -95,13 +93,14 @@ void Subtractor::HandleEchoPathChange( shadow_filter_.HandleEchoPathChange(); G_main_.HandleEchoPathChange(echo_path_variability); G_shadow_.HandleEchoPathChange(); - G_main_.SetConfig(config_.filter.main_initial); - G_shadow_.SetConfig(config_.filter.shadow_initial); + G_main_.SetConfig(config_.filter.main_initial, true); + G_shadow_.SetConfig(config_.filter.shadow_initial, true); main_filter_converged_ = false; shadow_filter_converged_ = false; - main_filter_.SetSizePartitions(config_.filter.main_initial.length_blocks); + main_filter_.SetSizePartitions(config_.filter.main_initial.length_blocks, + true); shadow_filter_.SetSizePartitions( - config_.filter.shadow_initial.length_blocks); + config_.filter.shadow_initial.length_blocks, true); }; // TODO(peah): Add delay-change specific reset behavior. @@ -120,10 +119,10 @@ void Subtractor::HandleEchoPathChange( } void Subtractor::ExitInitialState() { - G_main_.SetConfig(config_.filter.main); - G_shadow_.SetConfig(config_.filter.shadow); - main_filter_.SetSizePartitions(config_.filter.main.length_blocks); - shadow_filter_.SetSizePartitions(config_.filter.shadow.length_blocks); + G_main_.SetConfig(config_.filter.main, false); + G_shadow_.SetConfig(config_.filter.shadow, false); + main_filter_.SetSizePartitions(config_.filter.main.length_blocks, false); + shadow_filter_.SetSizePartitions(config_.filter.shadow.length_blocks, false); } void Subtractor::Process(const RenderBuffer& render_buffer, diff --git a/modules/audio_processing/aec3/suppression_gain.cc b/modules/audio_processing/aec3/suppression_gain.cc index 0962912d7f..53fd5758c7 100644 --- a/modules/audio_processing/aec3/suppression_gain.cc +++ b/modules/audio_processing/aec3/suppression_gain.cc @@ -107,78 +107,6 @@ float UpperBandsGain( return std::min(gain_below_8_khz, anti_howling_gain); } -// Limits the gain increase. -void UpdateMaxGainIncrease( - const EchoCanceller3Config& config, - size_t no_saturation_counter, - bool low_noise_render, - bool initial_state, - bool linear_echo_estimate, - const std::array& last_echo, - const std::array& echo, - const std::array& last_gain, - const std::array& new_gain, - std::array* gain_increase) { - float max_increasing; - float max_decreasing; - float rate_increasing; - float rate_decreasing; - float min_increasing; - float min_decreasing; - - auto& param = config.gain_updates; - if (!linear_echo_estimate) { - max_increasing = param.nonlinear.max_inc; - max_decreasing = param.nonlinear.max_dec; - rate_increasing = param.nonlinear.rate_inc; - rate_decreasing = param.nonlinear.rate_dec; - min_increasing = param.nonlinear.min_inc; - min_decreasing = param.nonlinear.min_dec; - } else if (initial_state && no_saturation_counter > 10) { - max_increasing = param.initial.max_inc; - max_decreasing = param.initial.max_dec; - rate_increasing = param.initial.rate_inc; - rate_decreasing = param.initial.rate_dec; - min_increasing = param.initial.min_inc; - min_decreasing = param.initial.min_dec; - } else if (low_noise_render) { - max_increasing = param.low_noise.max_inc; - max_decreasing = param.low_noise.max_dec; - rate_increasing = param.low_noise.rate_inc; - rate_decreasing = param.low_noise.rate_dec; - min_increasing = param.low_noise.min_inc; - min_decreasing = param.low_noise.min_dec; - } else if (no_saturation_counter > 10) { - max_increasing = param.normal.max_inc; - max_decreasing = param.normal.max_dec; - rate_increasing = param.normal.rate_inc; - rate_decreasing = param.normal.rate_dec; - min_increasing = param.normal.min_inc; - min_decreasing = param.normal.min_dec; - } else { - max_increasing = param.saturation.max_inc; - max_decreasing = param.saturation.max_dec; - rate_increasing = param.saturation.rate_inc; - rate_decreasing = param.saturation.rate_dec; - min_increasing = param.saturation.min_inc; - min_decreasing = param.saturation.min_dec; - } - - for (size_t k = 0; k < new_gain.size(); ++k) { - if (echo[k] > last_echo[k]) { - (*gain_increase)[k] = - new_gain[k] > last_gain[k] - ? std::min(max_increasing, (*gain_increase)[k] * rate_increasing) - : min_increasing; - } else { - (*gain_increase)[k] = - new_gain[k] > last_gain[k] - ? std::min(max_decreasing, (*gain_increase)[k] * rate_decreasing) - : min_decreasing; - } - } -} - // Computes the gain to reduce the echo to a non audible level. void GainToNoAudibleEcho( const EchoCanceller3Config& config, @@ -285,14 +213,15 @@ void AdjustNonConvergedFrequencies( void SuppressionGain::LowerBandGain( bool low_noise_render, const rtc::Optional& narrow_peak_band, - bool saturated_echo, - bool saturating_echo_path, - bool initial_state, - bool linear_echo_estimate, + const AecState& aec_state, const std::array& nearend, const std::array& echo, const std::array& comfort_noise, std::array* gain) { + const bool saturated_echo = aec_state.SaturatedEcho(); + const bool saturating_echo_path = aec_state.SaturatingEchoPath(); + const bool linear_echo_estimate = aec_state.UsableLinearEstimate(); + // Count the number of blocks since saturation. no_saturation_counter_ = saturated_echo ? 0 : no_saturation_counter_ + 1; @@ -346,9 +275,7 @@ void SuppressionGain::LowerBandGain( AdjustNonConvergedFrequencies(gain); // Update the allowed maximum gain increase. - UpdateMaxGainIncrease(config_, no_saturation_counter_, low_noise_render, - initial_state, linear_echo_estimate, last_echo_, echo, - last_gain_, *gain, &gain_increase_); + UpdateGainIncrease(low_noise_render, linear_echo_estimate, echo, *gain); // Adjust gain dynamics. const float gain_bound = @@ -366,7 +293,12 @@ void SuppressionGain::LowerBandGain( SuppressionGain::SuppressionGain(const EchoCanceller3Config& config, Aec3Optimization optimization) - : optimization_(optimization), config_(config) { + : optimization_(optimization), + config_(config), + state_change_duration_blocks_( + static_cast(config_.filter.config_change_duration_blocks)) { + RTC_DCHECK_LT(0, state_change_duration_blocks_); + one_by_state_change_duration_blocks_ = 1.f / state_change_duration_blocks_; last_gain_.fill(1.f); last_masker_.fill(0.f); gain_increase_.fill(1.f); @@ -385,21 +317,14 @@ void SuppressionGain::GetGain( RTC_DCHECK(high_bands_gain); RTC_DCHECK(low_band_gain); - const bool saturated_echo = aec_state.SaturatedEcho(); - const bool saturating_echo_path = aec_state.SaturatingEchoPath(); - const float gain_upper_bound = aec_state.SuppressionGainLimit(); - const bool linear_echo_estimate = aec_state.UsableLinearEstimate(); - const bool initial_state = aec_state.InitialState(); - - bool low_noise_render = low_render_detector_.Detect(render); - // Compute gain for the lower band. + bool low_noise_render = low_render_detector_.Detect(render); const rtc::Optional narrow_peak_band = render_signal_analyzer.NarrowPeakBand(); - LowerBandGain(low_noise_render, narrow_peak_band, saturated_echo, - saturating_echo_path, initial_state, linear_echo_estimate, - nearend, echo, comfort_noise, low_band_gain); + LowerBandGain(low_noise_render, narrow_peak_band, aec_state, nearend, echo, + comfort_noise, low_band_gain); + const float gain_upper_bound = aec_state.SuppressionGainLimit(); if (gain_upper_bound < 1.f) { for (size_t k = 0; k < low_band_gain->size(); ++k) { (*low_band_gain)[k] = std::min((*low_band_gain)[k], gain_upper_bound); @@ -407,8 +332,112 @@ void SuppressionGain::GetGain( } // Compute the gain for the upper bands. - *high_bands_gain = - UpperBandsGain(narrow_peak_band, saturated_echo, render, *low_band_gain); + *high_bands_gain = UpperBandsGain(narrow_peak_band, aec_state.SaturatedEcho(), + render, *low_band_gain); +} + +void SuppressionGain::SetInitialState(bool state) { + initial_state_ = state; + if (state) { + initial_state_change_counter_ = state_change_duration_blocks_; + } else { + initial_state_change_counter_ = 0; + } +} + +void SuppressionGain::UpdateGainIncrease( + bool low_noise_render, + bool linear_echo_estimate, + const std::array& echo, + const std::array& new_gain) { + float max_inc; + float max_dec; + float rate_inc; + float rate_dec; + float min_inc; + float min_dec; + + RTC_DCHECK_GE(state_change_duration_blocks_, initial_state_change_counter_); + if (initial_state_change_counter_ > 0) { + if (--initial_state_change_counter_ == 0) { + initial_state_ = false; + } + } + RTC_DCHECK_LE(0, initial_state_change_counter_); + + // EchoCanceller3Config::GainUpdates + auto& p = config_.gain_updates; + if (!linear_echo_estimate) { + max_inc = p.nonlinear.max_inc; + max_dec = p.nonlinear.max_dec; + rate_inc = p.nonlinear.rate_inc; + rate_dec = p.nonlinear.rate_dec; + min_inc = p.nonlinear.min_inc; + min_dec = p.nonlinear.min_dec; + } else if (initial_state_ && no_saturation_counter_ > 10) { + if (initial_state_change_counter_ > 0) { + float change_factor = + initial_state_change_counter_ * one_by_state_change_duration_blocks_; + + auto average = [](float from, float to, float from_weight) { + return from * from_weight + to * (1.f - from_weight); + }; + + max_inc = average(p.initial.max_inc, p.normal.max_inc, change_factor); + max_dec = average(p.initial.max_dec, p.normal.max_dec, change_factor); + rate_inc = average(p.initial.rate_inc, p.normal.rate_inc, change_factor); + rate_dec = average(p.initial.rate_dec, p.normal.rate_dec, change_factor); + min_inc = average(p.initial.min_inc, p.normal.min_inc, change_factor); + min_dec = average(p.initial.min_dec, p.normal.min_dec, change_factor); + } else { + max_inc = p.initial.max_inc; + max_dec = p.initial.max_dec; + rate_inc = p.initial.rate_inc; + rate_dec = p.initial.rate_dec; + min_inc = p.initial.min_inc; + min_dec = p.initial.min_dec; + } + } else if (low_noise_render) { + max_inc = p.low_noise.max_inc; + max_dec = p.low_noise.max_dec; + rate_inc = p.low_noise.rate_inc; + rate_dec = p.low_noise.rate_dec; + min_inc = p.low_noise.min_inc; + min_dec = p.low_noise.min_dec; + } else if (no_saturation_counter_ > 10) { + max_inc = p.normal.max_inc; + max_dec = p.normal.max_dec; + rate_inc = p.normal.rate_inc; + rate_dec = p.normal.rate_dec; + min_inc = p.normal.min_inc; + min_dec = p.normal.min_dec; + } else { + max_inc = p.saturation.max_inc; + max_dec = p.saturation.max_dec; + rate_inc = p.saturation.rate_inc; + rate_dec = p.saturation.rate_dec; + min_inc = p.saturation.min_inc; + min_dec = p.saturation.min_dec; + } + + for (size_t k = 0; k < new_gain.size(); ++k) { + auto increase_update = [](float new_gain, float last_gain, + float current_inc, float max_inc, float min_inc, + float change_rate) { + return new_gain > last_gain ? std::min(max_inc, current_inc * change_rate) + : min_inc; + }; + + if (echo[k] > last_echo_[k]) { + gain_increase_[k] = + increase_update(new_gain[k], last_gain_[k], gain_increase_[k], + max_inc, min_inc, rate_inc); + } else { + gain_increase_[k] = + increase_update(new_gain[k], last_gain_[k], gain_increase_[k], + max_dec, min_dec, rate_dec); + } + } } // Detects when the render signal can be considered to have low power and diff --git a/modules/audio_processing/aec3/suppression_gain.h b/modules/audio_processing/aec3/suppression_gain.h index d4cdff34b2..6624c1c05f 100644 --- a/modules/audio_processing/aec3/suppression_gain.h +++ b/modules/audio_processing/aec3/suppression_gain.h @@ -35,18 +35,25 @@ class SuppressionGain { float* high_bands_gain, std::array* low_band_gain); + // Toggles the usage of the initial state. + void SetInitialState(bool state); + private: void LowerBandGain(bool stationary_with_low_power, const rtc::Optional& narrow_peak_band, - bool saturated_echo, - bool saturating_echo_path, - bool initial_state, - bool linear_echo_estimate, + const AecState& aec_state, const std::array& nearend, const std::array& echo, const std::array& comfort_noise, std::array* gain); + // Limits the gain increase. + void UpdateGainIncrease( + bool low_noise_render, + bool linear_echo_estimate, + const std::array& echo, + const std::array& new_gain); + class LowNoiseRenderDetector { public: bool Detect(const std::vector>& render); @@ -56,6 +63,9 @@ class SuppressionGain { }; const Aec3Optimization optimization_; + const EchoCanceller3Config config_; + const int state_change_duration_blocks_; + float one_by_state_change_duration_blocks_; std::array last_gain_; std::array last_masker_; std::array gain_increase_; @@ -63,7 +73,8 @@ class SuppressionGain { LowNoiseRenderDetector low_render_detector_; size_t no_saturation_counter_ = 0; - const EchoCanceller3Config config_; + bool initial_state_ = true; + int initial_state_change_counter_ = 0; RTC_DISALLOW_COPY_AND_ASSIGN(SuppressionGain); };