From 785d4c40cac7caf62e39fac7eaa7a729d6895407 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85hgren?= Date: Thu, 17 Oct 2019 14:40:54 +0200 Subject: [PATCH] AEC3: Add multichannel support in the ERLE estimation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug: webrtc:10913 Change-Id: I1667146d38dc99d099b140f47cd774a7f203b4f0 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/157047 Reviewed-by: Sam Zackrisson Commit-Queue: Per Åhgren Cr-Commit-Position: refs/heads/master@{#29521} --- modules/audio_processing/aec3/aec_state.cc | 66 ++-- modules/audio_processing/aec3/aec_state.h | 3 +- .../audio_processing/aec3/erle_estimator.cc | 63 +-- .../audio_processing/aec3/erle_estimator.h | 40 +- .../aec3/erle_estimator_unittest.cc | 318 +++++++++------- .../aec3/fullband_erle_estimator.cc | 78 ++-- .../aec3/fullband_erle_estimator.h | 21 +- .../aec3/signal_dependent_erle_estimator.cc | 360 ++++++++++-------- .../aec3/signal_dependent_erle_estimator.h | 43 ++- ...ignal_dependent_erle_estimator_unittest.cc | 162 +++++--- .../aec3/subband_erle_estimator.cc | 200 ++++++---- .../aec3/subband_erle_estimator.h | 46 ++- .../aec3/subtractor_output_analyzer.cc | 43 ++- .../aec3/subtractor_output_analyzer.h | 20 +- 14 files changed, 838 insertions(+), 625 deletions(-) diff --git a/modules/audio_processing/aec3/aec_state.cc b/modules/audio_processing/aec3/aec_state.cc index 13b9bccf03..610412496d 100644 --- a/modules/audio_processing/aec3/aec_state.cc +++ b/modules/audio_processing/aec3/aec_state.cc @@ -44,7 +44,7 @@ void ComputeAvgRenderReverb( std::array X2_data; rtc::ArrayView X2; if (num_render_channels > 1) { - auto sum_channels = + auto average_channels = [](size_t num_render_channels, const std::vector>& spectrum_band_0, rtc::ArrayView render_power) { @@ -55,14 +55,18 @@ void ComputeAvgRenderReverb( render_power[k] += spectrum_band_0[ch][k]; } } + const float normalizer = 1.f / num_render_channels; + for (size_t k = 0; k 
< kFftLengthBy2Plus1; ++k) { + render_power[k] *= normalizer; + } }; - sum_channels(num_render_channels, spectrum_buffer.buffer[idx_past], - X2_data); + average_channels(num_render_channels, spectrum_buffer.buffer[idx_past], + X2_data); reverb_model->UpdateReverbNoFreqShaping( X2_data, /*power_spectrum_scaling=*/1.0f, reverb_decay); - sum_channels(num_render_channels, spectrum_buffer.buffer[idx_at_delay], - X2_data); + average_channels(num_render_channels, spectrum_buffer.buffer[idx_at_delay], + X2_data); X2 = X2_data; } else { reverb_model->UpdateReverbNoFreqShaping( @@ -110,17 +114,18 @@ AecState::AecState(const EchoCanceller3Config& config, : data_dumper_( new ApmDataDumper(rtc::AtomicOps::Increment(&instance_count_))), config_(config), + num_capture_channels_(num_capture_channels), initial_state_(config_), - delay_state_(config_, num_capture_channels), + delay_state_(config_, num_capture_channels_), transparent_state_(config_), - filter_quality_state_(config_, num_capture_channels), + filter_quality_state_(config_, num_capture_channels_), erl_estimator_(2 * kNumBlocksPerSecond), - erle_estimator_(2 * kNumBlocksPerSecond, config_, num_capture_channels), - filter_analyzer_(config_, num_capture_channels), + erle_estimator_(2 * kNumBlocksPerSecond, config_, num_capture_channels_), + filter_analyzer_(config_, num_capture_channels_), echo_audibility_( config_.echo_audibility.use_stationarity_properties_at_init), - reverb_model_estimator_(config_, num_capture_channels), - subtractor_output_analyzers_(num_capture_channels) {} + reverb_model_estimator_(config_, num_capture_channels_), + subtractor_output_analyzer_(num_capture_channels_) {} AecState::~AecState() = default; @@ -147,9 +152,7 @@ void AecState::HandleEchoPathChange( } else if (echo_path_variability.gain_change) { erle_estimator_.Reset(false); } - for (auto& analyzer : subtractor_output_analyzers_) { - analyzer.HandleEchoPathChange(); - } + subtractor_output_analyzer_.HandleEchoPathChange(); } void 
AecState::Update( @@ -161,25 +164,19 @@ void AecState::Update( rtc::ArrayView> E2_main, rtc::ArrayView> Y2, rtc::ArrayView subtractor_output) { - const size_t num_capture_channels = subtractor_output_analyzers_.size(); - RTC_DCHECK_EQ(num_capture_channels, E2_main.size()); - RTC_DCHECK_EQ(num_capture_channels, Y2.size()); - RTC_DCHECK_EQ(num_capture_channels, subtractor_output.size()); - RTC_DCHECK_EQ(num_capture_channels, subtractor_output_analyzers_.size()); - RTC_DCHECK_EQ(num_capture_channels, + RTC_DCHECK_EQ(num_capture_channels_, Y2.size()); + RTC_DCHECK_EQ(num_capture_channels_, subtractor_output.size()); + RTC_DCHECK_EQ(num_capture_channels_, adaptive_filter_frequency_responses.size()); - RTC_DCHECK_EQ(num_capture_channels, adaptive_filter_impulse_responses.size()); + RTC_DCHECK_EQ(num_capture_channels_, + adaptive_filter_impulse_responses.size()); // Analyze the filter outputs and filters. - bool any_filter_converged = false; - bool all_filters_diverged = true; - for (size_t ch = 0; ch < subtractor_output.size(); ++ch) { - subtractor_output_analyzers_[ch].Update(subtractor_output[ch]); - any_filter_converged = any_filter_converged || - subtractor_output_analyzers_[ch].ConvergedFilter(); - all_filters_diverged = all_filters_diverged && - subtractor_output_analyzers_[ch].DivergedFilter(); - } + bool any_filter_converged; + bool all_filters_diverged; + subtractor_output_analyzer_.Update(subtractor_output, &any_filter_converged, + &all_filters_diverged); + bool any_filter_consistent; float max_echo_path_gain; filter_analyzer_.Update(adaptive_filter_impulse_responses, render_buffer, @@ -229,16 +226,15 @@ void AecState::Update( erle_estimator_.Reset(false); } - erle_estimator_.Update(render_buffer, adaptive_filter_frequency_responses[0], - avg_render_spectrum_with_reverb, Y2[0], E2_main[0], - subtractor_output_analyzers_[0].ConvergedFilter(), - config_.erle.onset_detection); + erle_estimator_.Update(render_buffer, adaptive_filter_frequency_responses, + 
avg_render_spectrum_with_reverb, Y2, E2_main, + subtractor_output_analyzer_.ConvergedFilters()); // TODO(bugs.webrtc.org/10913): Take all channels into account. const auto& X2 = render_buffer.Spectrum(delay_state_.MinDirectPathFilterDelay(), /*channel=*/0); - erl_estimator_.Update(subtractor_output_analyzers_[0].ConvergedFilter(), X2, + erl_estimator_.Update(subtractor_output_analyzer_.ConvergedFilters()[0], X2, Y2[0]); // Detect and flag echo saturation. diff --git a/modules/audio_processing/aec3/aec_state.h b/modules/audio_processing/aec3/aec_state.h index 71000b4328..53b8be03e2 100644 --- a/modules/audio_processing/aec3/aec_state.h +++ b/modules/audio_processing/aec3/aec_state.h @@ -150,6 +150,7 @@ class AecState { static int instance_count_; std::unique_ptr data_dumper_; const EchoCanceller3Config config_; + const size_t num_capture_channels_; // Class for controlling the transition from the intial state, which in turn // controls when the filter parameters for the initial state should be used. 
@@ -314,7 +315,7 @@ class AecState { EchoAudibility echo_audibility_; ReverbModelEstimator reverb_model_estimator_; ReverbModel avg_render_reverb_; - std::vector subtractor_output_analyzers_; + SubtractorOutputAnalyzer subtractor_output_analyzer_; }; } // namespace webrtc diff --git a/modules/audio_processing/aec3/erle_estimator.cc b/modules/audio_processing/aec3/erle_estimator.cc index a3f68d175b..4d843457d3 100644 --- a/modules/audio_processing/aec3/erle_estimator.cc +++ b/modules/audio_processing/aec3/erle_estimator.cc @@ -15,14 +15,17 @@ namespace webrtc { -ErleEstimator::ErleEstimator(size_t startup_phase_length_blocks_, +ErleEstimator::ErleEstimator(size_t startup_phase_length_blocks, const EchoCanceller3Config& config, size_t num_capture_channels) - : startup_phase_length_blocks__(startup_phase_length_blocks_), - use_signal_dependent_erle_(config.erle.num_sections > 1), + : startup_phase_length_blocks_(startup_phase_length_blocks), fullband_erle_estimator_(config.erle, num_capture_channels), - subband_erle_estimator_(config, num_capture_channels), - signal_dependent_erle_estimator_(config, num_capture_channels) { + subband_erle_estimator_(config, num_capture_channels) { + if (config.erle.num_sections > 1) { + signal_dependent_erle_estimator_ = + std::make_unique(config, + num_capture_channels); + } Reset(true); } @@ -31,7 +34,9 @@ ErleEstimator::~ErleEstimator() = default; void ErleEstimator::Reset(bool delay_change) { fullband_erle_estimator_.Reset(); subband_erle_estimator_.Reset(); - signal_dependent_erle_estimator_.Reset(); + if (signal_dependent_erle_estimator_) { + signal_dependent_erle_estimator_->Reset(); + } if (delay_change) { blocks_since_reset_ = 0; } @@ -39,41 +44,43 @@ void ErleEstimator::Reset(bool delay_change) { void ErleEstimator::Update( const RenderBuffer& render_buffer, - const std::vector>& - filter_frequency_response, - rtc::ArrayView reverb_render_spectrum, - rtc::ArrayView capture_spectrum, - rtc::ArrayView subtractor_spectrum, - 
bool converged_filter, - bool onset_detection) { - RTC_DCHECK_EQ(kFftLengthBy2Plus1, reverb_render_spectrum.size()); - RTC_DCHECK_EQ(kFftLengthBy2Plus1, capture_spectrum.size()); - RTC_DCHECK_EQ(kFftLengthBy2Plus1, subtractor_spectrum.size()); - const auto& X2_reverb = reverb_render_spectrum; - const auto& Y2 = capture_spectrum; - const auto& E2 = subtractor_spectrum; + rtc::ArrayView>> + filter_frequency_responses, + rtc::ArrayView + avg_render_spectrum_with_reverb, + rtc::ArrayView> capture_spectra, + rtc::ArrayView> + subtractor_spectra, + const std::vector& converged_filters) { + RTC_DCHECK_EQ(subband_erle_estimator_.Erle().size(), capture_spectra.size()); + RTC_DCHECK_EQ(subband_erle_estimator_.Erle().size(), + subtractor_spectra.size()); + const auto& X2_reverb = avg_render_spectrum_with_reverb; + const auto& Y2 = capture_spectra; + const auto& E2 = subtractor_spectra; - if (++blocks_since_reset_ < startup_phase_length_blocks__) { + if (++blocks_since_reset_ < startup_phase_length_blocks_) { return; } - subband_erle_estimator_.Update(X2_reverb, Y2, E2, converged_filter, - onset_detection); + subband_erle_estimator_.Update(X2_reverb, Y2, E2, converged_filters); - if (use_signal_dependent_erle_) { - signal_dependent_erle_estimator_.Update( - render_buffer, filter_frequency_response, X2_reverb, Y2, E2, - subband_erle_estimator_.Erle(), converged_filter); + if (signal_dependent_erle_estimator_) { + signal_dependent_erle_estimator_->Update( + render_buffer, filter_frequency_responses, X2_reverb, Y2, E2, + subband_erle_estimator_.Erle(), converged_filters); } - fullband_erle_estimator_.Update(X2_reverb, Y2, E2, converged_filter); + fullband_erle_estimator_.Update(X2_reverb, Y2, E2, converged_filters); } void ErleEstimator::Dump( const std::unique_ptr& data_dumper) const { fullband_erle_estimator_.Dump(data_dumper); subband_erle_estimator_.Dump(data_dumper); - signal_dependent_erle_estimator_.Dump(data_dumper); + if (signal_dependent_erle_estimator_) { + 
signal_dependent_erle_estimator_->Dump(data_dumper); + } } } // namespace webrtc diff --git a/modules/audio_processing/aec3/erle_estimator.h b/modules/audio_processing/aec3/erle_estimator.h index cac6741226..d741cff3da 100644 --- a/modules/audio_processing/aec3/erle_estimator.h +++ b/modules/audio_processing/aec3/erle_estimator.h @@ -15,6 +15,7 @@ #include #include +#include #include "absl/types/optional.h" #include "api/array_view.h" @@ -32,7 +33,7 @@ namespace webrtc { // and another one is done using the aggreation of energy over all the subbands. class ErleEstimator { public: - ErleEstimator(size_t startup_phase_length_blocks_, + ErleEstimator(size_t startup_phase_length_blocks, const EchoCanceller3Config& config, size_t num_capture_channels); ~ErleEstimator(); @@ -41,24 +42,29 @@ class ErleEstimator { void Reset(bool delay_change); // Updates the ERLE estimates. - void Update(const RenderBuffer& render_buffer, - const std::vector>& - filter_frequency_response, - rtc::ArrayView reverb_render_spectrum, - rtc::ArrayView capture_spectrum, - rtc::ArrayView subtractor_spectrum, - bool converged_filter, - bool onset_detection); + void Update( + const RenderBuffer& render_buffer, + rtc::ArrayView>> + filter_frequency_responses, + rtc::ArrayView + avg_render_spectrum_with_reverb, + rtc::ArrayView> + capture_spectra, + rtc::ArrayView> + subtractor_spectra, + const std::vector& converged_filters); // Returns the most recent subband ERLE estimates. rtc::ArrayView> Erle() const { - return use_signal_dependent_erle_ ? signal_dependent_erle_estimator_.Erle() - : subband_erle_estimator_.Erle(); + return signal_dependent_erle_estimator_ + ? signal_dependent_erle_estimator_->Erle() + : subband_erle_estimator_.Erle(); } - // Returns the subband ERLE that are estimated during onsets. Used - // for logging/testing. - rtc::ArrayView ErleOnsets() const { + // Returns the subband ERLE that are estimated during onsets (only used for + // testing). 
+ rtc::ArrayView> ErleOnsets() + const { return subband_erle_estimator_.ErleOnsets(); } @@ -80,11 +86,11 @@ class ErleEstimator { void Dump(const std::unique_ptr& data_dumper) const; private: - const size_t startup_phase_length_blocks__; - const bool use_signal_dependent_erle_; + const size_t startup_phase_length_blocks_; FullBandErleEstimator fullband_erle_estimator_; SubbandErleEstimator subband_erle_estimator_; - SignalDependentErleEstimator signal_dependent_erle_estimator_; + std::unique_ptr + signal_dependent_erle_estimator_; size_t blocks_since_reset_ = 0; }; diff --git a/modules/audio_processing/aec3/erle_estimator_unittest.cc b/modules/audio_processing/aec3/erle_estimator_unittest.cc index e8f99bc44e..48a6d6cecd 100644 --- a/modules/audio_processing/aec3/erle_estimator_unittest.cc +++ b/modules/audio_processing/aec3/erle_estimator_unittest.cc @@ -27,21 +27,25 @@ constexpr float kTrueErle = 10.f; constexpr float kTrueErleOnsets = 1.0f; constexpr float kEchoPathGain = 3.f; -void VerifyErleBands(rtc::ArrayView erle, - float reference_lf, - float reference_hf) { - std::for_each( - erle.begin(), erle.begin() + kLowFrequencyLimit, - [reference_lf](float a) { EXPECT_NEAR(reference_lf, a, 0.001); }); - std::for_each( - erle.begin() + kLowFrequencyLimit, erle.end(), - [reference_hf](float a) { EXPECT_NEAR(reference_hf, a, 0.001); }); +void VerifyErleBands( + rtc::ArrayView> erle, + float reference_lf, + float reference_hf) { + for (size_t ch = 0; ch < erle.size(); ++ch) { + std::for_each( + erle[ch].begin(), erle[ch].begin() + kLowFrequencyLimit, + [reference_lf](float a) { EXPECT_NEAR(reference_lf, a, 0.001); }); + std::for_each( + erle[ch].begin() + kLowFrequencyLimit, erle[ch].end(), + [reference_hf](float a) { EXPECT_NEAR(reference_hf, a, 0.001); }); + } } -void VerifyErle(rtc::ArrayView erle, - float erle_time_domain, - float reference_lf, - float reference_hf) { +void VerifyErle( + rtc::ArrayView> erle, + float erle_time_domain, + float reference_lf, + float 
reference_hf) { VerifyErleBands(erle, reference_lf, reference_hf); EXPECT_NEAR(reference_lf, erle_time_domain, 0.5); } @@ -65,160 +69,210 @@ void FormFarendTimeFrame(std::vector>>* x) { } void FormFarendFrame(const RenderBuffer& render_buffer, + float erle, std::array* X2, - std::array* E2, - std::array* Y2, - float erle) { + rtc::ArrayView> E2, + rtc::ArrayView> Y2) { const auto& spectrum_buffer = render_buffer.GetSpectrumBuffer(); - const auto& X2_from_buffer = - spectrum_buffer.buffer[spectrum_buffer.write][/*channel=*/0]; - std::copy(X2_from_buffer.begin(), X2_from_buffer.end(), X2->begin()); - std::transform(X2->begin(), X2->end(), Y2->begin(), - [](float a) { return a * kEchoPathGain * kEchoPathGain; }); - std::transform(Y2->begin(), Y2->end(), E2->begin(), - [erle](float a) { return a / erle; }); + const int num_render_channels = spectrum_buffer.buffer[0].size(); + const int num_capture_channels = Y2.size(); -} // namespace - -void FormNearendFrame(std::vector>>* x, - std::array* X2, - std::array* E2, - std::array* Y2) { - for (size_t band = 0; band < x->size(); ++band) { - for (size_t channel = 0; channel < (*x)[band].size(); ++channel) { - std::fill((*x)[band][channel].begin(), (*x)[band][channel].end(), 0.f); - X2->fill(0.f); - Y2->fill(500.f * 1000.f * 1000.f); - E2->fill((*Y2)[0]); + X2->fill(0.f); + for (int ch = 0; ch < num_render_channels; ++ch) { + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + (*X2)[k] += spectrum_buffer.buffer[spectrum_buffer.write][ch][k] / + num_render_channels; } } + + for (int ch = 0; ch < num_capture_channels; ++ch) { + std::transform(X2->begin(), X2->end(), Y2[ch].begin(), + [](float a) { return a * kEchoPathGain * kEchoPathGain; }); + std::transform(Y2[ch].begin(), Y2[ch].end(), E2[ch].begin(), + [erle](float a) { return a / erle; }); + } } -void GetFilterFreq(std::vector>& - filter_frequency_response, - size_t delay_headroom_samples) { - const size_t delay_headroom_blocks = delay_headroom_samples / kBlockSize; - for 
(auto& block_freq_resp : filter_frequency_response) { - block_freq_resp.fill(0.f); +void FormNearendFrame( + std::vector>>* x, + std::array* X2, + rtc::ArrayView> E2, + rtc::ArrayView> Y2) { + for (size_t band = 0; band < x->size(); ++band) { + for (size_t ch = 0; ch < (*x)[band].size(); ++ch) { + std::fill((*x)[band][ch].begin(), (*x)[band][ch].end(), 0.f); + } } - for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { - filter_frequency_response[delay_headroom_blocks][k] = kEchoPathGain; + X2->fill(0.f); + for (size_t ch = 0; ch < Y2.size(); ++ch) { + Y2[ch].fill(500.f * 1000.f * 1000.f); + E2[ch].fill(Y2[ch][0]); + } +} + +void GetFilterFreq( + size_t delay_headroom_samples, + rtc::ArrayView>> + filter_frequency_response) { + const size_t delay_headroom_blocks = delay_headroom_samples / kBlockSize; + for (size_t ch = 0; ch < filter_frequency_response[0].size(); ++ch) { + for (auto& block_freq_resp : filter_frequency_response) { + block_freq_resp[ch].fill(0.f); + } + + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + filter_frequency_response[delay_headroom_blocks][ch][k] = kEchoPathGain; + } } } } // namespace TEST(ErleEstimator, VerifyErleIncreaseAndHold) { - std::array X2; - std::array E2; - std::array Y2; - constexpr size_t kNumRenderChannels = 1; - constexpr size_t kNumCaptureChannels = 1; constexpr int kSampleRateHz = 48000; constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); - EchoCanceller3Config config; - std::vector>> x( - kNumBands, std::vector>( - kNumRenderChannels, std::vector(kBlockSize, 0.f))); - std::vector> filter_frequency_response( - config.filter.main.length_blocks); - std::unique_ptr render_delay_buffer( - RenderDelayBuffer::Create(config, kSampleRateHz, kNumRenderChannels)); + for (size_t num_render_channels : {1, 2, 4, 8}) { + for (size_t num_capture_channels : {1, 2, 4}) { + std::array X2; + std::vector> E2( + num_capture_channels); + std::vector> Y2( + num_capture_channels); + std::vector converged_filters(num_capture_channels, 
true); - GetFilterFreq(filter_frequency_response, config.delay.delay_headroom_samples); + EchoCanceller3Config config; + config.erle.onset_detection = true; - ErleEstimator estimator(0, config, kNumCaptureChannels); + std::vector>> x( + kNumBands, + std::vector>(num_render_channels, + std::vector(kBlockSize, 0.f))); + std::vector>> + filter_frequency_response( + config.filter.main.length_blocks, + std::vector>( + num_capture_channels)); + std::unique_ptr render_delay_buffer( + RenderDelayBuffer::Create(config, kSampleRateHz, + num_render_channels)); - FormFarendTimeFrame(&x); - render_delay_buffer->Insert(x); - render_delay_buffer->PrepareCaptureProcessing(); - // Verifies that the ERLE estimate is properly increased to higher values. - FormFarendFrame(*render_delay_buffer->GetRenderBuffer(), &X2, &E2, &Y2, - kTrueErle); - for (size_t k = 0; k < 200; ++k) { - render_delay_buffer->Insert(x); - render_delay_buffer->PrepareCaptureProcessing(); - estimator.Update(*render_delay_buffer->GetRenderBuffer(), - filter_frequency_response, X2, Y2, E2, true, true); + GetFilterFreq(config.delay.delay_headroom_samples, + filter_frequency_response); + + ErleEstimator estimator(0, config, num_capture_channels); + + FormFarendTimeFrame(&x); + render_delay_buffer->Insert(x); + render_delay_buffer->PrepareCaptureProcessing(); + // Verifies that the ERLE estimate is properly increased to higher values. + FormFarendFrame(*render_delay_buffer->GetRenderBuffer(), kTrueErle, &X2, + E2, Y2); + for (size_t k = 0; k < 200; ++k) { + render_delay_buffer->Insert(x); + render_delay_buffer->PrepareCaptureProcessing(); + estimator.Update(*render_delay_buffer->GetRenderBuffer(), + filter_frequency_response, X2, Y2, E2, + converged_filters); + } + VerifyErle(estimator.Erle(), std::pow(2.f, estimator.FullbandErleLog2()), + config.erle.max_l, config.erle.max_h); + + FormNearendFrame(&x, &X2, E2, Y2); + // Verifies that the ERLE is not immediately decreased during nearend + // activity. 
+ for (size_t k = 0; k < 50; ++k) { + render_delay_buffer->Insert(x); + render_delay_buffer->PrepareCaptureProcessing(); + estimator.Update(*render_delay_buffer->GetRenderBuffer(), + filter_frequency_response, X2, Y2, E2, + converged_filters); + } + VerifyErle(estimator.Erle(), std::pow(2.f, estimator.FullbandErleLog2()), + config.erle.max_l, config.erle.max_h); + } } - VerifyErle(estimator.Erle()[0], std::pow(2.f, estimator.FullbandErleLog2()), - config.erle.max_l, config.erle.max_h); - - FormNearendFrame(&x, &X2, &E2, &Y2); - // Verifies that the ERLE is not immediately decreased during nearend - // activity. - for (size_t k = 0; k < 50; ++k) { - render_delay_buffer->Insert(x); - render_delay_buffer->PrepareCaptureProcessing(); - estimator.Update(*render_delay_buffer->GetRenderBuffer(), - filter_frequency_response, X2, Y2, E2, true, true); - } - VerifyErle(estimator.Erle()[0], std::pow(2.f, estimator.FullbandErleLog2()), - config.erle.max_l, config.erle.max_h); } TEST(ErleEstimator, VerifyErleTrackingOnOnsets) { - constexpr size_t kNumRenderChannels = 1; - constexpr size_t kNumCaptureChannels = 1; constexpr int kSampleRateHz = 48000; constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); - std::array X2; - std::array E2; - std::array Y2; - EchoCanceller3Config config; - std::vector>> x( - kNumBands, std::vector>( - kNumRenderChannels, std::vector(kBlockSize, 0.f))); - std::vector> filter_frequency_response( - config.filter.main.length_blocks); - std::unique_ptr render_delay_buffer( - RenderDelayBuffer::Create(config, kSampleRateHz, kNumRenderChannels)); + for (size_t num_render_channels : {1, 2, 4, 8}) { + for (size_t num_capture_channels : {1, 2, 4}) { + std::array X2; + std::vector> E2( + num_capture_channels); + std::vector> Y2( + num_capture_channels); + std::vector converged_filters(num_capture_channels, true); + EchoCanceller3Config config; + config.erle.onset_detection = true; + std::vector>> x( + kNumBands, + std::vector>(num_render_channels, + 
std::vector(kBlockSize, 0.f))); + std::vector>> + filter_frequency_response( + config.filter.main.length_blocks, + std::vector>( + num_capture_channels)); + std::unique_ptr render_delay_buffer( + RenderDelayBuffer::Create(config, kSampleRateHz, + num_render_channels)); - GetFilterFreq(filter_frequency_response, config.delay.delay_headroom_samples); + GetFilterFreq(config.delay.delay_headroom_samples, + filter_frequency_response); - ErleEstimator estimator(0, config, kNumCaptureChannels); + ErleEstimator estimator(/*startup_phase_length_blocks=*/0, config, + num_capture_channels); - FormFarendTimeFrame(&x); - render_delay_buffer->Insert(x); - render_delay_buffer->PrepareCaptureProcessing(); - - for (size_t burst = 0; burst < 20; ++burst) { - FormFarendFrame(*render_delay_buffer->GetRenderBuffer(), &X2, &E2, &Y2, - kTrueErleOnsets); - for (size_t k = 0; k < 10; ++k) { + FormFarendTimeFrame(&x); render_delay_buffer->Insert(x); render_delay_buffer->PrepareCaptureProcessing(); - estimator.Update(*render_delay_buffer->GetRenderBuffer(), - filter_frequency_response, X2, Y2, E2, true, true); - } - FormFarendFrame(*render_delay_buffer->GetRenderBuffer(), &X2, &E2, &Y2, - kTrueErle); - for (size_t k = 0; k < 200; ++k) { - render_delay_buffer->Insert(x); - render_delay_buffer->PrepareCaptureProcessing(); - estimator.Update(*render_delay_buffer->GetRenderBuffer(), - filter_frequency_response, X2, Y2, E2, true, true); - } - FormNearendFrame(&x, &X2, &E2, &Y2); - for (size_t k = 0; k < 300; ++k) { - render_delay_buffer->Insert(x); - render_delay_buffer->PrepareCaptureProcessing(); - estimator.Update(*render_delay_buffer->GetRenderBuffer(), - filter_frequency_response, X2, Y2, E2, true, true); + + for (size_t burst = 0; burst < 20; ++burst) { + FormFarendFrame(*render_delay_buffer->GetRenderBuffer(), + kTrueErleOnsets, &X2, E2, Y2); + for (size_t k = 0; k < 10; ++k) { + render_delay_buffer->Insert(x); + render_delay_buffer->PrepareCaptureProcessing(); + 
estimator.Update(*render_delay_buffer->GetRenderBuffer(), + filter_frequency_response, X2, Y2, E2, + converged_filters); + } + FormFarendFrame(*render_delay_buffer->GetRenderBuffer(), kTrueErle, &X2, + E2, Y2); + for (size_t k = 0; k < 200; ++k) { + render_delay_buffer->Insert(x); + render_delay_buffer->PrepareCaptureProcessing(); + estimator.Update(*render_delay_buffer->GetRenderBuffer(), + filter_frequency_response, X2, Y2, E2, + converged_filters); + } + FormNearendFrame(&x, &X2, E2, Y2); + for (size_t k = 0; k < 300; ++k) { + render_delay_buffer->Insert(x); + render_delay_buffer->PrepareCaptureProcessing(); + estimator.Update(*render_delay_buffer->GetRenderBuffer(), + filter_frequency_response, X2, Y2, E2, + converged_filters); + } + } + VerifyErleBands(estimator.ErleOnsets(), config.erle.min, config.erle.min); + FormNearendFrame(&x, &X2, E2, Y2); + for (size_t k = 0; k < 1000; k++) { + estimator.Update(*render_delay_buffer->GetRenderBuffer(), + filter_frequency_response, X2, Y2, E2, + converged_filters); + } + // Verifies that during ne activity, Erle converges to the Erle for + // onsets. + VerifyErle(estimator.Erle(), std::pow(2.f, estimator.FullbandErleLog2()), + config.erle.min, config.erle.min); } } - VerifyErleBands(estimator.ErleOnsets(), config.erle.min, config.erle.min); - FormNearendFrame(&x, &X2, &E2, &Y2); - for (size_t k = 0; k < 1000; k++) { - estimator.Update(*render_delay_buffer->GetRenderBuffer(), - filter_frequency_response, X2, Y2, E2, true, true); - } - // Verifies that during ne activity, Erle converges to the Erle for onsets. 
- VerifyErle(estimator.Erle()[0], std::pow(2.f, estimator.FullbandErleLog2()), - config.erle.min, config.erle.min); } } // namespace webrtc diff --git a/modules/audio_processing/aec3/fullband_erle_estimator.cc b/modules/audio_processing/aec3/fullband_erle_estimator.cc index 086638d6b5..e421214821 100644 --- a/modules/audio_processing/aec3/fullband_erle_estimator.cc +++ b/modules/audio_processing/aec3/fullband_erle_estimator.cc @@ -35,7 +35,9 @@ FullBandErleEstimator::FullBandErleEstimator( size_t num_capture_channels) : min_erle_log2_(FastApproxLog2f(config.min + kEpsilon)), max_erle_lf_log2(FastApproxLog2f(config.max_l + kEpsilon)), - instantaneous_erle_(config), + hold_counters_time_domain_(num_capture_channels, 0), + erle_time_domain_log2_(num_capture_channels, min_erle_log2_), + instantaneous_erle_(num_capture_channels, ErleInstantaneous(config)), linear_filters_qualities_(num_capture_channels) { Reset(); } @@ -43,39 +45,49 @@ FullBandErleEstimator::FullBandErleEstimator( FullBandErleEstimator::~FullBandErleEstimator() = default; void FullBandErleEstimator::Reset() { - instantaneous_erle_.Reset(); + for (auto& instantaneous_erle_ch : instantaneous_erle_) { + instantaneous_erle_ch.Reset(); + } + UpdateQualityEstimates(); - erle_time_domain_log2_ = min_erle_log2_; - hold_counter_time_domain_ = 0; + std::fill(erle_time_domain_log2_.begin(), erle_time_domain_log2_.end(), + min_erle_log2_); + std::fill(hold_counters_time_domain_.begin(), + hold_counters_time_domain_.end(), 0); } -void FullBandErleEstimator::Update(rtc::ArrayView X2, - rtc::ArrayView Y2, - rtc::ArrayView E2, - bool converged_filter) { - if (converged_filter) { - // Computes the fullband ERLE. 
- const float X2_sum = std::accumulate(X2.begin(), X2.end(), 0.0f); - if (X2_sum > kX2BandEnergyThreshold * X2.size()) { - const float Y2_sum = std::accumulate(Y2.begin(), Y2.end(), 0.0f); - const float E2_sum = std::accumulate(E2.begin(), E2.end(), 0.0f); - if (instantaneous_erle_.Update(Y2_sum, E2_sum)) { - hold_counter_time_domain_ = kBlocksToHoldErle; - erle_time_domain_log2_ += - 0.1f * ((instantaneous_erle_.GetInstErleLog2().value()) - - erle_time_domain_log2_); - erle_time_domain_log2_ = rtc::SafeClamp( - erle_time_domain_log2_, min_erle_log2_, max_erle_lf_log2); +void FullBandErleEstimator::Update( + rtc::ArrayView X2, + rtc::ArrayView> Y2, + rtc::ArrayView> E2, + const std::vector& converged_filters) { + for (size_t ch = 0; ch < Y2.size(); ++ch) { + if (converged_filters[ch]) { + // Computes the fullband ERLE. + const float X2_sum = std::accumulate(X2.begin(), X2.end(), 0.0f); + if (X2_sum > kX2BandEnergyThreshold * X2.size()) { + const float Y2_sum = + std::accumulate(Y2[ch].begin(), Y2[ch].end(), 0.0f); + const float E2_sum = + std::accumulate(E2[ch].begin(), E2[ch].end(), 0.0f); + if (instantaneous_erle_[ch].Update(Y2_sum, E2_sum)) { + hold_counters_time_domain_[ch] = kBlocksToHoldErle; + erle_time_domain_log2_[ch] += + 0.1f * ((instantaneous_erle_[ch].GetInstErleLog2().value()) - + erle_time_domain_log2_[ch]); + erle_time_domain_log2_[ch] = rtc::SafeClamp( + erle_time_domain_log2_[ch], min_erle_log2_, max_erle_lf_log2); + } } } - } - --hold_counter_time_domain_; - if (hold_counter_time_domain_ <= 0) { - erle_time_domain_log2_ = - std::max(min_erle_log2_, erle_time_domain_log2_ - 0.044f); - } - if (hold_counter_time_domain_ == 0) { - instantaneous_erle_.ResetAccumulators(); + --hold_counters_time_domain_[ch]; + if (hold_counters_time_domain_[ch] <= 0) { + erle_time_domain_log2_[ch] = + std::max(min_erle_log2_, erle_time_domain_log2_[ch] - 0.044f); + } + if (hold_counters_time_domain_[ch] == 0) { + instantaneous_erle_[ch].ResetAccumulators(); + } } 
UpdateQualityEstimates(); @@ -84,12 +96,14 @@ void FullBandErleEstimator::Update(rtc::ArrayView X2, void FullBandErleEstimator::Dump( const std::unique_ptr& data_dumper) const { data_dumper->DumpRaw("aec3_fullband_erle_log2", FullbandErleLog2()); - instantaneous_erle_.Dump(data_dumper); + instantaneous_erle_[0].Dump(data_dumper); } void FullBandErleEstimator::UpdateQualityEstimates() { - std::fill(linear_filters_qualities_.begin(), linear_filters_qualities_.end(), - instantaneous_erle_.GetQualityEstimate()); + for (size_t ch = 0; ch < instantaneous_erle_.size(); ++ch) { + linear_filters_qualities_[ch] = + instantaneous_erle_[ch].GetQualityEstimate(); + } } FullBandErleEstimator::ErleInstantaneous::ErleInstantaneous( diff --git a/modules/audio_processing/aec3/fullband_erle_estimator.h b/modules/audio_processing/aec3/fullband_erle_estimator.h index 64372a2009..1580f1a8a5 100644 --- a/modules/audio_processing/aec3/fullband_erle_estimator.h +++ b/modules/audio_processing/aec3/fullband_erle_estimator.h @@ -17,6 +17,7 @@ #include "absl/types/optional.h" #include "api/array_view.h" #include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" #include "modules/audio_processing/logging/apm_data_dumper.h" namespace webrtc { @@ -33,12 +34,18 @@ class FullBandErleEstimator { // Updates the ERLE estimator. void Update(rtc::ArrayView X2, - rtc::ArrayView Y2, - rtc::ArrayView E2, - bool converged_filter); + rtc::ArrayView> Y2, + rtc::ArrayView> E2, + const std::vector& converged_filters); // Returns the fullband ERLE estimates in log2 units. - float FullbandErleLog2() const { return erle_time_domain_log2_; } + float FullbandErleLog2() const { + float min_erle = erle_time_domain_log2_[0]; + for (size_t ch = 1; ch < erle_time_domain_log2_.size(); ++ch) { + min_erle = std::min(min_erle, erle_time_domain_log2_[ch]); + } + return min_erle; + } // Returns an estimation of the current linear filter quality. 
It returns a // float number between 0 and 1 mapping 1 to the highest possible quality. @@ -98,11 +105,11 @@ class FullBandErleEstimator { int num_points_; }; - int hold_counter_time_domain_; - float erle_time_domain_log2_; const float min_erle_log2_; const float max_erle_lf_log2; - ErleInstantaneous instantaneous_erle_; + std::vector hold_counters_time_domain_; + std::vector erle_time_domain_log2_; + std::vector instantaneous_erle_; std::vector> linear_filters_qualities_; }; diff --git a/modules/audio_processing/aec3/signal_dependent_erle_estimator.cc b/modules/audio_processing/aec3/signal_dependent_erle_estimator.cc index d3c07a1bf1..d99b7f3e25 100644 --- a/modules/audio_processing/aec3/signal_dependent_erle_estimator.cc +++ b/modules/audio_processing/aec3/signal_dependent_erle_estimator.cc @@ -132,29 +132,38 @@ SignalDependentErleEstimator::SignalDependentErleEstimator( num_blocks_, num_sections_)), erle_(num_capture_channels), - S2_section_accum_(num_sections_), - erle_estimators_(num_sections_), - correction_factors_(num_sections_) { + S2_section_accum_( + num_capture_channels, + std::vector>(num_sections_)), + erle_estimators_( + num_capture_channels, + std::vector>(num_sections_)), + erle_ref_(num_capture_channels), + correction_factors_( + num_capture_channels, + std::vector>(num_sections_)), + num_updates_(num_capture_channels), + n_active_sections_(num_capture_channels) { RTC_DCHECK_LE(num_sections_, num_blocks_); RTC_DCHECK_GE(num_sections_, 1); - Reset(); } SignalDependentErleEstimator::~SignalDependentErleEstimator() = default; void SignalDependentErleEstimator::Reset() { - for (auto& erle : erle_) { - erle.fill(min_erle_); + for (size_t ch = 0; ch < erle_.size(); ++ch) { + erle_[ch].fill(min_erle_); + for (auto& erle_estimator : erle_estimators_[ch]) { + erle_estimator.fill(min_erle_); + } + erle_ref_[ch].fill(min_erle_); + for (auto& factor : correction_factors_[ch]) { + factor.fill(1.0f); + } + num_updates_[ch].fill(0); + 
n_active_sections_[ch].fill(0); } - for (auto& erle_estimator : erle_estimators_) { - erle_estimator.fill(min_erle_); - } - erle_ref_.fill(min_erle_); - for (auto& factor : correction_factors_) { - factor.fill(1.0f); - } - num_updates_.fill(0); } // Updates the Erle estimate by analyzing the current input signals. It takes @@ -165,44 +174,45 @@ void SignalDependentErleEstimator::Reset() { // correction factor to the erle that is given as an input to this method. void SignalDependentErleEstimator::Update( const RenderBuffer& render_buffer, - const std::vector>& - filter_frequency_response, - rtc::ArrayView X2, - rtc::ArrayView Y2, - rtc::ArrayView E2, + rtc::ArrayView>> + filter_frequency_responses, + rtc::ArrayView X2, + rtc::ArrayView> Y2, + rtc::ArrayView> E2, rtc::ArrayView> average_erle, - bool converged_filter) { + const std::vector& converged_filters) { RTC_DCHECK_GT(num_sections_, 1); // Gets the number of filter sections that are needed for achieving 90 % // of the power spectrum energy of the echo estimate. - std::array n_active_sections; - ComputeNumberOfActiveFilterSections(render_buffer, filter_frequency_response, - n_active_sections); + ComputeNumberOfActiveFilterSections(render_buffer, + filter_frequency_responses); - if (converged_filter) { - // Updates the correction factor that is used for correcting the erle and - // adapt it to the particular characteristics of the input signal. - UpdateCorrectionFactors(X2, Y2, E2, n_active_sections); - } + // Updates the correction factors that is used for correcting the erle and + // adapt it to the particular characteristics of the input signal. + UpdateCorrectionFactors(X2, Y2, E2, converged_filters); // Applies the correction factor to the input erle for getting a more refined // erle estimation for the current input signal. 
- for (size_t k = 0; k < kFftLengthBy2; ++k) { - float correction_factor = - correction_factors_[n_active_sections[k]][band_to_subband_[k]]; - erle_[0][k] = rtc::SafeClamp(average_erle[0][k] * correction_factor, - min_erle_, max_erle_[band_to_subband_[k]]); + for (size_t ch = 0; ch < erle_.size(); ++ch) { + for (size_t k = 0; k < kFftLengthBy2; ++k) { + RTC_DCHECK_GT(correction_factors_[ch].size(), n_active_sections_[ch][k]); + float correction_factor = + correction_factors_[ch][n_active_sections_[ch][k]] + [band_to_subband_[k]]; + erle_[ch][k] = rtc::SafeClamp(average_erle[ch][k] * correction_factor, + min_erle_, max_erle_[band_to_subband_[k]]); + } } } void SignalDependentErleEstimator::Dump( const std::unique_ptr& data_dumper) const { - for (auto& erle : erle_estimators_) { + for (auto& erle : erle_estimators_[0]) { data_dumper->DumpRaw("aec3_all_erle", erle); } - data_dumper->DumpRaw("aec3_ref_erle", erle_ref_); - for (auto& factor : correction_factors_) { + data_dumper->DumpRaw("aec3_ref_erle", erle_ref_[0]); + for (auto& factor : correction_factors_[0]) { data_dumper->DumpRaw("aec3_erle_correction_factor", factor); } } @@ -211,163 +221,185 @@ void SignalDependentErleEstimator::Dump( // together constitute 90% of the estimated echo energy. void SignalDependentErleEstimator::ComputeNumberOfActiveFilterSections( const RenderBuffer& render_buffer, - const std::vector>& - filter_frequency_response, - rtc::ArrayView n_active_filter_sections) { + rtc::ArrayView>> + filter_frequency_responses) { RTC_DCHECK_GT(num_sections_, 1); // Computes an approximation of the power spectrum if the filter would have // been limited to a certain number of filter sections. - ComputeEchoEstimatePerFilterSection(render_buffer, filter_frequency_response); + ComputeEchoEstimatePerFilterSection(render_buffer, + filter_frequency_responses); // For each band, computes the number of filter sections that are needed for // achieving the 90 % energy in the echo estimate. 
- ComputeActiveFilterSections(n_active_filter_sections); + ComputeActiveFilterSections(); } void SignalDependentErleEstimator::UpdateCorrectionFactors( - rtc::ArrayView X2, - rtc::ArrayView Y2, - rtc::ArrayView E2, - rtc::ArrayView n_active_sections) { - constexpr float kX2BandEnergyThreshold = 44015068.0f; - constexpr float kSmthConstantDecreases = 0.1f; - constexpr float kSmthConstantIncreases = kSmthConstantDecreases / 2.f; - auto subband_powers = [](rtc::ArrayView power_spectrum, - rtc::ArrayView power_spectrum_subbands) { - for (size_t subband = 0; subband < kSubbands; ++subband) { - RTC_DCHECK_LE(kBandBoundaries[subband + 1], power_spectrum.size()); - power_spectrum_subbands[subband] = std::accumulate( - power_spectrum.begin() + kBandBoundaries[subband], - power_spectrum.begin() + kBandBoundaries[subband + 1], 0.f); - } - }; + rtc::ArrayView X2, + rtc::ArrayView> Y2, + rtc::ArrayView> E2, + const std::vector& converged_filters) { + for (size_t ch = 0; ch < converged_filters.size(); ++ch) { + if (converged_filters[ch]) { + constexpr float kX2BandEnergyThreshold = 44015068.0f; + constexpr float kSmthConstantDecreases = 0.1f; + constexpr float kSmthConstantIncreases = kSmthConstantDecreases / 2.f; + auto subband_powers = [](rtc::ArrayView power_spectrum, + rtc::ArrayView power_spectrum_subbands) { + for (size_t subband = 0; subband < kSubbands; ++subband) { + RTC_DCHECK_LE(kBandBoundaries[subband + 1], power_spectrum.size()); + power_spectrum_subbands[subband] = std::accumulate( + power_spectrum.begin() + kBandBoundaries[subband], + power_spectrum.begin() + kBandBoundaries[subband + 1], 0.f); + } + }; - std::array X2_subbands, E2_subbands, Y2_subbands; - subband_powers(X2, X2_subbands); - subband_powers(E2, E2_subbands); - subband_powers(Y2, Y2_subbands); - std::array idx_subbands; - for (size_t subband = 0; subband < kSubbands; ++subband) { - // When aggregating the number of active sections in the filter for - // different bands we choose to take the minimum 
of all of them. As an - // example, if for one of the bands it is the direct path its main - // contributor to the final echo estimate, we consider the direct path is - // as well the main contributor for the subband that contains that - // particular band. That aggregate number of sections will be later used as - // the identifier of the erle estimator that needs to be updated. - RTC_DCHECK_LE(kBandBoundaries[subband + 1], n_active_sections.size()); - idx_subbands[subband] = *std::min_element( - n_active_sections.begin() + kBandBoundaries[subband], - n_active_sections.begin() + kBandBoundaries[subband + 1]); - } + std::array X2_subbands, E2_subbands, Y2_subbands; + subband_powers(X2, X2_subbands); + subband_powers(E2[ch], E2_subbands); + subband_powers(Y2[ch], Y2_subbands); + std::array idx_subbands; + for (size_t subband = 0; subband < kSubbands; ++subband) { + // When aggregating the number of active sections in the filter for + // different bands we choose to take the minimum of all of them. As an + // example, if for one of the bands it is the direct path its main + // contributor to the final echo estimate, we consider the direct path + // is as well the main contributor for the subband that contains that + // particular band. That aggregate number of sections will be later used + // as the identifier of the erle estimator that needs to be updated. 
+ RTC_DCHECK_LE(kBandBoundaries[subband + 1], + n_active_sections_[ch].size()); + idx_subbands[subband] = *std::min_element( + n_active_sections_[ch].begin() + kBandBoundaries[subband], + n_active_sections_[ch].begin() + kBandBoundaries[subband + 1]); + } - std::array new_erle; - std::array is_erle_updated; - is_erle_updated.fill(false); - new_erle.fill(0.f); - for (size_t subband = 0; subband < kSubbands; ++subband) { - if (X2_subbands[subband] > kX2BandEnergyThreshold && - E2_subbands[subband] > 0) { - new_erle[subband] = Y2_subbands[subband] / E2_subbands[subband]; - RTC_DCHECK_GT(new_erle[subband], 0); - is_erle_updated[subband] = true; - ++num_updates_[subband]; - } - } + std::array new_erle; + std::array is_erle_updated; + is_erle_updated.fill(false); + new_erle.fill(0.f); + for (size_t subband = 0; subband < kSubbands; ++subband) { + if (X2_subbands[subband] > kX2BandEnergyThreshold && + E2_subbands[subband] > 0) { + new_erle[subband] = Y2_subbands[subband] / E2_subbands[subband]; + RTC_DCHECK_GT(new_erle[subband], 0); + is_erle_updated[subband] = true; + ++num_updates_[ch][subband]; + } + } - for (size_t subband = 0; subband < kSubbands; ++subband) { - const size_t idx = idx_subbands[subband]; - RTC_DCHECK_LT(idx, erle_estimators_.size()); - float alpha = new_erle[subband] > erle_estimators_[idx][subband] - ? kSmthConstantIncreases - : kSmthConstantDecreases; - alpha = static_cast(is_erle_updated[subband]) * alpha; - erle_estimators_[idx][subband] += - alpha * (new_erle[subband] - erle_estimators_[idx][subband]); - erle_estimators_[idx][subband] = rtc::SafeClamp( - erle_estimators_[idx][subband], min_erle_, max_erle_[subband]); - } + for (size_t subband = 0; subband < kSubbands; ++subband) { + const size_t idx = idx_subbands[subband]; + RTC_DCHECK_LT(idx, erle_estimators_[ch].size()); + float alpha = new_erle[subband] > erle_estimators_[ch][idx][subband] + ? 
kSmthConstantIncreases + : kSmthConstantDecreases; + alpha = static_cast(is_erle_updated[subband]) * alpha; + erle_estimators_[ch][idx][subband] += + alpha * (new_erle[subband] - erle_estimators_[ch][idx][subband]); + erle_estimators_[ch][idx][subband] = rtc::SafeClamp( + erle_estimators_[ch][idx][subband], min_erle_, max_erle_[subband]); + } - for (size_t subband = 0; subband < kSubbands; ++subband) { - float alpha = new_erle[subband] > erle_ref_[subband] - ? kSmthConstantIncreases - : kSmthConstantDecreases; - alpha = static_cast(is_erle_updated[subband]) * alpha; - erle_ref_[subband] += alpha * (new_erle[subband] - erle_ref_[subband]); - erle_ref_[subband] = - rtc::SafeClamp(erle_ref_[subband], min_erle_, max_erle_[subband]); - } + for (size_t subband = 0; subband < kSubbands; ++subband) { + float alpha = new_erle[subband] > erle_ref_[ch][subband] + ? kSmthConstantIncreases + : kSmthConstantDecreases; + alpha = static_cast(is_erle_updated[subband]) * alpha; + erle_ref_[ch][subband] += + alpha * (new_erle[subband] - erle_ref_[ch][subband]); + erle_ref_[ch][subband] = rtc::SafeClamp(erle_ref_[ch][subband], + min_erle_, max_erle_[subband]); + } - for (size_t subband = 0; subband < kSubbands; ++subband) { - constexpr int kNumUpdateThr = 50; - if (is_erle_updated[subband] && num_updates_[subband] > kNumUpdateThr) { - const size_t idx = idx_subbands[subband]; - RTC_DCHECK_GT(erle_ref_[subband], 0.f); - // Computes the ratio between the erle that is updated using all the - // points and the erle that is updated only on signals that share the - // same number of active filter sections. 
- float new_correction_factor = - erle_estimators_[idx][subband] / erle_ref_[subband]; + for (size_t subband = 0; subband < kSubbands; ++subband) { + constexpr int kNumUpdateThr = 50; + if (is_erle_updated[subband] && + num_updates_[ch][subband] > kNumUpdateThr) { + const size_t idx = idx_subbands[subband]; + RTC_DCHECK_GT(erle_ref_[ch][subband], 0.f); + // Computes the ratio between the erle that is updated using all the + // points and the erle that is updated only on signals that share the + // same number of active filter sections. + float new_correction_factor = + erle_estimators_[ch][idx][subband] / erle_ref_[ch][subband]; - correction_factors_[idx][subband] += - 0.1f * (new_correction_factor - correction_factors_[idx][subband]); + correction_factors_[ch][idx][subband] += + 0.1f * + (new_correction_factor - correction_factors_[ch][idx][subband]); + } + } } } } void SignalDependentErleEstimator::ComputeEchoEstimatePerFilterSection( const RenderBuffer& render_buffer, - const std::vector>& - filter_frequency_response) { + rtc::ArrayView>> + filter_frequency_responses) { const SpectrumBuffer& spectrum_render_buffer = render_buffer.GetSpectrumBuffer(); + const size_t num_render_channels = spectrum_render_buffer.buffer[0].size(); + const size_t num_capture_channels = S2_section_accum_.size(); + const float one_by_num_render_channels = 1.f / num_render_channels; - RTC_DCHECK_EQ(S2_section_accum_.size() + 1, - section_boundaries_blocks_.size()); - size_t idx_render = render_buffer.Position(); - idx_render = spectrum_render_buffer.OffsetIndex( - idx_render, section_boundaries_blocks_[0]); + RTC_DCHECK_EQ(S2_section_accum_.size(), filter_frequency_responses.size()); - for (size_t section = 0; section < num_sections_; ++section) { - std::array X2_section; - std::array H2_section; - X2_section.fill(0.f); - H2_section.fill(0.f); - const size_t block_limit = std::min(section_boundaries_blocks_[section + 1], - filter_frequency_response.size()); - for (size_t block = 
section_boundaries_blocks_[section]; - block < block_limit; ++block) { - std::transform( - X2_section.begin(), X2_section.end(), - spectrum_render_buffer.buffer[idx_render][/*channel=*/0].begin(), - X2_section.begin(), std::plus()); - std::transform(H2_section.begin(), H2_section.end(), - filter_frequency_response[block].begin(), - H2_section.begin(), std::plus()); - idx_render = spectrum_render_buffer.IncIndex(idx_render); + for (size_t capture_ch = 0; capture_ch < num_capture_channels; ++capture_ch) { + RTC_DCHECK_EQ(S2_section_accum_[capture_ch].size() + 1, + section_boundaries_blocks_.size()); + size_t idx_render = render_buffer.Position(); + idx_render = spectrum_render_buffer.OffsetIndex( + idx_render, section_boundaries_blocks_[0]); + + for (size_t section = 0; section < num_sections_; ++section) { + std::array X2_section; + std::array H2_section; + X2_section.fill(0.f); + H2_section.fill(0.f); + const size_t block_limit = + std::min(section_boundaries_blocks_[section + 1], + filter_frequency_responses[capture_ch].size()); + for (size_t block = section_boundaries_blocks_[section]; + block < block_limit; ++block) { + for (size_t render_ch = 0; + render_ch < spectrum_render_buffer.buffer[idx_render].size(); + ++render_ch) { + for (size_t k = 0; k < X2_section.size(); ++k) { + X2_section[k] += + spectrum_render_buffer.buffer[idx_render][render_ch][k] * + one_by_num_render_channels; + } + } + std::transform(H2_section.begin(), H2_section.end(), + filter_frequency_responses[capture_ch][block].begin(), + H2_section.begin(), std::plus()); + idx_render = spectrum_render_buffer.IncIndex(idx_render); + } + + std::transform(X2_section.begin(), X2_section.end(), H2_section.begin(), + S2_section_accum_[capture_ch][section].begin(), + std::multiplies()); } - std::transform(X2_section.begin(), X2_section.end(), H2_section.begin(), - S2_section_accum_[section].begin(), - std::multiplies()); - } - - for (size_t section = 1; section < num_sections_; ++section) { - 
std::transform(S2_section_accum_[section - 1].begin(), - S2_section_accum_[section - 1].end(), - S2_section_accum_[section].begin(), - S2_section_accum_[section].begin(), std::plus()); + for (size_t section = 1; section < num_sections_; ++section) { + std::transform(S2_section_accum_[capture_ch][section - 1].begin(), + S2_section_accum_[capture_ch][section - 1].end(), + S2_section_accum_[capture_ch][section].begin(), + S2_section_accum_[capture_ch][section].begin(), + std::plus()); + } } } -void SignalDependentErleEstimator::ComputeActiveFilterSections( - rtc::ArrayView number_active_filter_sections) const { - std::fill(number_active_filter_sections.begin(), - number_active_filter_sections.end(), 0); - for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { - size_t section = num_sections_; - float target = 0.9f * S2_section_accum_[num_sections_ - 1][k]; - while (section > 0 && S2_section_accum_[section - 1][k] >= target) { - number_active_filter_sections[k] = --section; +void SignalDependentErleEstimator::ComputeActiveFilterSections() { + for (size_t ch = 0; ch < n_active_sections_.size(); ++ch) { + std::fill(n_active_sections_[ch].begin(), n_active_sections_[ch].end(), 0); + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + size_t section = num_sections_; + float target = 0.9f * S2_section_accum_[ch][num_sections_ - 1][k]; + while (section > 0 && S2_section_accum_[ch][section - 1][k] >= target) { + n_active_sections_[ch][k] = --section; + } } } } diff --git a/modules/audio_processing/aec3/signal_dependent_erle_estimator.h b/modules/audio_processing/aec3/signal_dependent_erle_estimator.h index da0b8ab61a..498e922f13 100644 --- a/modules/audio_processing/aec3/signal_dependent_erle_estimator.h +++ b/modules/audio_processing/aec3/signal_dependent_erle_estimator.h @@ -45,13 +45,13 @@ class SignalDependentErleEstimator { // to be an estimation of the average Erle achieved by the linear filter. 
void Update( const RenderBuffer& render_buffer, - const std::vector>& + rtc::ArrayView>> filter_frequency_response, - rtc::ArrayView X2, - rtc::ArrayView Y2, - rtc::ArrayView E2, + rtc::ArrayView X2, + rtc::ArrayView> Y2, + rtc::ArrayView> E2, rtc::ArrayView> average_erle, - bool converged_filter); + const std::vector& converged_filters); void Dump(const std::unique_ptr& data_dumper) const; @@ -60,22 +60,21 @@ class SignalDependentErleEstimator { private: void ComputeNumberOfActiveFilterSections( const RenderBuffer& render_buffer, - const std::vector>& - filter_frequency_response, - rtc::ArrayView n_active_filter_sections); + rtc::ArrayView>> + filter_frequency_responses); - void UpdateCorrectionFactors(rtc::ArrayView X2, - rtc::ArrayView Y2, - rtc::ArrayView E2, - rtc::ArrayView n_active_sections); + void UpdateCorrectionFactors( + rtc::ArrayView X2, + rtc::ArrayView> Y2, + rtc::ArrayView> E2, + const std::vector& converged_filters); void ComputeEchoEstimatePerFilterSection( const RenderBuffer& render_buffer, - const std::vector>& - filter_frequency_response); + rtc::ArrayView>> + filter_frequency_responses); - void ComputeActiveFilterSections( - rtc::ArrayView number_active_filter_sections) const; + void ComputeActiveFilterSections(); const float min_erle_; const size_t num_sections_; @@ -85,11 +84,13 @@ class SignalDependentErleEstimator { const std::array max_erle_; const std::vector section_boundaries_blocks_; std::vector> erle_; - std::vector> S2_section_accum_; - std::vector> erle_estimators_; - std::array erle_ref_; - std::vector> correction_factors_; - std::array num_updates_; + std::vector>> + S2_section_accum_; + std::vector>> erle_estimators_; + std::vector> erle_ref_; + std::vector>> correction_factors_; + std::vector> num_updates_; + std::vector> n_active_sections_; }; } // namespace webrtc diff --git a/modules/audio_processing/aec3/signal_dependent_erle_estimator_unittest.cc b/modules/audio_processing/aec3/signal_dependent_erle_estimator_unittest.cc 
index ccc2ef3455..394310d9e2 100644 --- a/modules/audio_processing/aec3/signal_dependent_erle_estimator_unittest.cc +++ b/modules/audio_processing/aec3/signal_dependent_erle_estimator_unittest.cc @@ -44,13 +44,25 @@ void GetActiveFrame(std::vector>>* x) { class TestInputs { public: - explicit TestInputs(const EchoCanceller3Config& cfg); + TestInputs(const EchoCanceller3Config& cfg, + size_t num_render_channels, + size_t num_capture_channels); ~TestInputs(); const RenderBuffer& GetRenderBuffer() { return *render_buffer_; } - rtc::ArrayView GetX2() { return X2_; } - rtc::ArrayView GetY2() { return Y2_; } - rtc::ArrayView GetE2() { return E2_; } - std::vector> GetH2() { return H2_; } + rtc::ArrayView GetX2() { return X2_; } + rtc::ArrayView> GetY2() const { + return Y2_; + } + rtc::ArrayView> GetE2() const { + return E2_; + } + rtc::ArrayView>> + GetH2() const { + return H2_; + } + const std::vector& GetConvergedFilters() const { + return converged_filters_; + } void Update(); private: @@ -59,24 +71,37 @@ class TestInputs { std::unique_ptr render_delay_buffer_; RenderBuffer* render_buffer_; std::array X2_; - std::array Y2_; - std::array E2_; - std::vector> H2_; + std::vector> Y2_; + std::vector> E2_; + std::vector>> H2_; std::vector>> x_; + std::vector converged_filters_; }; -TestInputs::TestInputs(const EchoCanceller3Config& cfg) - : render_delay_buffer_(RenderDelayBuffer::Create(cfg, 16000, 1)), - H2_(cfg.filter.main.length_blocks), +TestInputs::TestInputs(const EchoCanceller3Config& cfg, + size_t num_render_channels, + size_t num_capture_channels) + : render_delay_buffer_( + RenderDelayBuffer::Create(cfg, 16000, num_render_channels)), + Y2_(num_capture_channels), + E2_(num_capture_channels), + H2_(num_capture_channels, + std::vector>( + cfg.filter.main.length_blocks)), x_(1, - std::vector>(1, - std::vector(kBlockSize, 0.f))) { + std::vector>(num_render_channels, + std::vector(kBlockSize, 0.f))), + converged_filters_(num_capture_channels, true) { 
render_delay_buffer_->AlignFromDelay(4); render_buffer_ = render_delay_buffer_->GetRenderBuffer(); - for (auto& H : H2_) { - H.fill(0.f); + for (auto& H2_ch : H2_) { + for (auto& H2_p : H2_ch) { + H2_p.fill(0.f); + } + } + for (auto& H2_p : H2_[0]) { + H2_p.fill(1.f); } - H2_[0].fill(1.0f); } TestInputs::~TestInputs() = default; @@ -102,40 +127,47 @@ void TestInputs::UpdateCurrentPowerSpectra() { auto& X2 = spectrum_render_buffer.buffer[idx][/*channel=*/0]; auto& X2_prev = spectrum_render_buffer.buffer[prev_idx][/*channel=*/0]; std::copy(X2.begin(), X2.end(), X2_.begin()); - RTC_DCHECK_EQ(X2.size(), Y2_.size()); - for (size_t k = 0; k < X2.size(); ++k) { - E2_[k] = 0.01f * X2_prev[k]; - Y2_[k] = X2[k] + E2_[k]; + for (size_t ch = 0; ch < Y2_.size(); ++ch) { + RTC_DCHECK_EQ(X2.size(), Y2_[ch].size()); + for (size_t k = 0; k < X2.size(); ++k) { + E2_[ch][k] = 0.01f * X2_prev[k]; + Y2_[ch][k] = X2[k] + E2_[ch][k]; + } } } } // namespace TEST(SignalDependentErleEstimator, SweepSettings) { - const size_t kNumCaptureChannels = 1; - EchoCanceller3Config cfg; - size_t max_length_blocks = 50; - for (size_t blocks = 0; blocks < max_length_blocks; blocks = blocks + 10) { - for (size_t delay_headroom = 0; delay_headroom < 5; ++delay_headroom) { - for (size_t num_sections = 2; num_sections < max_length_blocks; - ++num_sections) { - cfg.filter.main.length_blocks = blocks; - cfg.filter.main_initial.length_blocks = - std::min(cfg.filter.main_initial.length_blocks, blocks); - cfg.delay.delay_headroom_samples = delay_headroom * kBlockSize; - cfg.erle.num_sections = num_sections; - if (EchoCanceller3Config::Validate(&cfg)) { - SignalDependentErleEstimator s(cfg, kNumCaptureChannels); - std::array, kNumCaptureChannels> - average_erle; - for (auto& e : average_erle) { - e.fill(cfg.erle.max_l); - } - TestInputs inputs(cfg); - for (size_t n = 0; n < 10; ++n) { - inputs.Update(); - s.Update(inputs.GetRenderBuffer(), inputs.GetH2(), inputs.GetX2(), - inputs.GetY2(), inputs.GetE2(), 
average_erle, true); + for (size_t num_render_channels : {1, 2, 4}) { + for (size_t num_capture_channels : {1, 2, 4}) { + EchoCanceller3Config cfg; + size_t max_length_blocks = 50; + for (size_t blocks = 0; blocks < max_length_blocks; + blocks = blocks + 10) { + for (size_t delay_headroom = 0; delay_headroom < 5; ++delay_headroom) { + for (size_t num_sections = 2; num_sections < max_length_blocks; + ++num_sections) { + cfg.filter.main.length_blocks = blocks; + cfg.filter.main_initial.length_blocks = + std::min(cfg.filter.main_initial.length_blocks, blocks); + cfg.delay.delay_headroom_samples = delay_headroom * kBlockSize; + cfg.erle.num_sections = num_sections; + if (EchoCanceller3Config::Validate(&cfg)) { + SignalDependentErleEstimator s(cfg, num_capture_channels); + std::vector> average_erle( + num_capture_channels); + for (auto& e : average_erle) { + e.fill(cfg.erle.max_l); + } + TestInputs inputs(cfg, num_render_channels, num_capture_channels); + for (size_t n = 0; n < 10; ++n) { + inputs.Update(); + s.Update(inputs.GetRenderBuffer(), inputs.GetH2(), + inputs.GetX2(), inputs.GetY2(), inputs.GetE2(), + average_erle, inputs.GetConvergedFilters()); + } + } } } } @@ -144,25 +176,29 @@ TEST(SignalDependentErleEstimator, SweepSettings) { } TEST(SignalDependentErleEstimator, LongerRun) { - const size_t kNumCaptureChannels = 1; - EchoCanceller3Config cfg; - cfg.filter.main.length_blocks = 2; - cfg.filter.main_initial.length_blocks = 1; - cfg.delay.delay_headroom_samples = 0; - cfg.delay.hysteresis_limit_blocks = 0; - cfg.erle.num_sections = 2; - EXPECT_EQ(EchoCanceller3Config::Validate(&cfg), true); - std::array, kNumCaptureChannels> - average_erle; - for (auto& e : average_erle) { - e.fill(cfg.erle.max_l); - } - SignalDependentErleEstimator s(cfg, kNumCaptureChannels); - TestInputs inputs(cfg); - for (size_t n = 0; n < 200; ++n) { - inputs.Update(); - s.Update(inputs.GetRenderBuffer(), inputs.GetH2(), inputs.GetX2(), - inputs.GetY2(), inputs.GetE2(), average_erle, 
true); + for (size_t num_render_channels : {1, 2, 4}) { + for (size_t num_capture_channels : {1, 2, 4}) { + EchoCanceller3Config cfg; + cfg.filter.main.length_blocks = 2; + cfg.filter.main_initial.length_blocks = 1; + cfg.delay.delay_headroom_samples = 0; + cfg.delay.hysteresis_limit_blocks = 0; + cfg.erle.num_sections = 2; + EXPECT_EQ(EchoCanceller3Config::Validate(&cfg), true); + std::vector> average_erle( + num_capture_channels); + for (auto& e : average_erle) { + e.fill(cfg.erle.max_l); + } + SignalDependentErleEstimator s(cfg, num_capture_channels); + TestInputs inputs(cfg, num_render_channels, num_capture_channels); + for (size_t n = 0; n < 200; ++n) { + inputs.Update(); + s.Update(inputs.GetRenderBuffer(), inputs.GetH2(), inputs.GetX2(), + inputs.GetY2(), inputs.GetE2(), average_erle, + inputs.GetConvergedFilters()); + } + } } } diff --git a/modules/audio_processing/aec3/subband_erle_estimator.cc b/modules/audio_processing/aec3/subband_erle_estimator.cc index 137b0558fd..6c00091266 100644 --- a/modules/audio_processing/aec3/subband_erle_estimator.cc +++ b/modules/audio_processing/aec3/subband_erle_estimator.cc @@ -42,10 +42,15 @@ bool EnableMinErleDuringOnsets() { SubbandErleEstimator::SubbandErleEstimator(const EchoCanceller3Config& config, size_t num_capture_channels) - : min_erle_(config.erle.min), + : use_onset_detection_(config.erle.onset_detection), + min_erle_(config.erle.min), max_erle_(SetMaxErleBands(config.erle.max_l, config.erle.max_h)), use_min_erle_during_onsets_(EnableMinErleDuringOnsets()), - erle_(num_capture_channels) { + accum_spectra_(num_capture_channels), + erle_(num_capture_channels), + erle_onsets_(num_capture_channels), + coming_onset_(num_capture_channels), + hold_counters_(num_capture_channels) { Reset(); } @@ -55,26 +60,23 @@ void SubbandErleEstimator::Reset() { for (auto& erle : erle_) { erle.fill(min_erle_); } - erle_onsets_.fill(min_erle_); - coming_onset_.fill(true); - hold_counters_.fill(0); + for (size_t ch = 0; ch < 
erle_onsets_.size(); ++ch) { + erle_onsets_[ch].fill(min_erle_); + coming_onset_[ch].fill(true); + hold_counters_[ch].fill(0); + } ResetAccumulatedSpectra(); } -void SubbandErleEstimator::Update(rtc::ArrayView X2, - rtc::ArrayView Y2, - rtc::ArrayView E2, - bool converged_filter, - bool onset_detection) { - if (converged_filter) { - // Note that the use of the converged_filter flag already imposed - // a minimum of the erle that can be estimated as that flag would - // be false if the filter is performing poorly. - UpdateAccumulatedSpectra(X2, Y2, E2); - UpdateBands(onset_detection); - } +void SubbandErleEstimator::Update( + rtc::ArrayView X2, + rtc::ArrayView> Y2, + rtc::ArrayView> E2, + const std::vector& converged_filters) { + UpdateAccumulatedSpectra(X2, Y2, E2, converged_filters); + UpdateBands(converged_filters); - if (onset_detection) { + if (use_onset_detection_) { DecreaseErlePerBandForLowRenderSignals(); } @@ -86,97 +88,129 @@ void SubbandErleEstimator::Update(rtc::ArrayView X2, void SubbandErleEstimator::Dump( const std::unique_ptr& data_dumper) const { - data_dumper->DumpRaw("aec3_erle_onset", ErleOnsets()); + data_dumper->DumpRaw("aec3_erle_onset", ErleOnsets()[0]); } -void SubbandErleEstimator::UpdateBands(bool onset_detection) { - std::array new_erle; - std::array is_erle_updated; - is_erle_updated.fill(false); - - for (size_t k = 1; k < kFftLengthBy2; ++k) { - if (accum_spectra_.num_points_[k] == kPointsToAccumulate && - accum_spectra_.E2_[k] > 0.f) { - new_erle[k] = accum_spectra_.Y2_[k] / accum_spectra_.E2_[k]; - is_erle_updated[k] = true; +void SubbandErleEstimator::UpdateBands( + const std::vector& converged_filters) { + const int num_capture_channels = static_cast(accum_spectra_.Y2.size()); + for (int ch = 0; ch < num_capture_channels; ++ch) { + // Note that the use of the converged_filter flag already imposed + // a minimum of the erle that can be estimated as that flag would + // be false if the filter is performing poorly. 
+ if (!converged_filters[ch]) { + continue; } - } - if (onset_detection) { + std::array new_erle; + std::array is_erle_updated; + is_erle_updated.fill(false); + for (size_t k = 1; k < kFftLengthBy2; ++k) { - if (is_erle_updated[k] && !accum_spectra_.low_render_energy_[k]) { - if (coming_onset_[k]) { - coming_onset_[k] = false; - if (!use_min_erle_during_onsets_) { - float alpha = new_erle[k] < erle_onsets_[k] ? 0.3f : 0.15f; - erle_onsets_[k] = rtc::SafeClamp( - erle_onsets_[k] + alpha * (new_erle[k] - erle_onsets_[k]), - min_erle_, max_erle_[k]); - } - } - hold_counters_[k] = kBlocksForOnsetDetection; + if (accum_spectra_.num_points[ch] == kPointsToAccumulate && + accum_spectra_.E2[ch][k] > 0.f) { + new_erle[k] = accum_spectra_.Y2[ch][k] / accum_spectra_.E2[ch][k]; + is_erle_updated[k] = true; } } - } - for (size_t k = 1; k < kFftLengthBy2; ++k) { - if (is_erle_updated[k]) { - float alpha = 0.05f; - if (new_erle[k] < erle_[0][k]) { - alpha = accum_spectra_.low_render_energy_[k] ? 0.f : 0.1f; + if (use_onset_detection_) { + for (size_t k = 1; k < kFftLengthBy2; ++k) { + if (is_erle_updated[k] && !accum_spectra_.low_render_energy[ch][k]) { + if (coming_onset_[ch][k]) { + coming_onset_[ch][k] = false; + if (!use_min_erle_during_onsets_) { + float alpha = new_erle[k] < erle_onsets_[ch][k] ? 0.3f : 0.15f; + erle_onsets_[ch][k] = rtc::SafeClamp( + erle_onsets_[ch][k] + + alpha * (new_erle[k] - erle_onsets_[ch][k]), + min_erle_, max_erle_[k]); + } + } + hold_counters_[ch][k] = kBlocksForOnsetDetection; + } + } + } + + for (size_t k = 1; k < kFftLengthBy2; ++k) { + if (is_erle_updated[k]) { + float alpha = 0.05f; + if (new_erle[k] < erle_[ch][k]) { + alpha = accum_spectra_.low_render_energy[ch][k] ? 
0.f : 0.1f; + } + erle_[ch][k] = + rtc::SafeClamp(erle_[ch][k] + alpha * (new_erle[k] - erle_[ch][k]), + min_erle_, max_erle_[k]); } - erle_[0][k] = - rtc::SafeClamp(erle_[0][k] + alpha * (new_erle[k] - erle_[0][k]), - min_erle_, max_erle_[k]); } } } void SubbandErleEstimator::DecreaseErlePerBandForLowRenderSignals() { - for (size_t k = 1; k < kFftLengthBy2; ++k) { - hold_counters_[k]--; - if (hold_counters_[k] <= (kBlocksForOnsetDetection - kBlocksToHoldErle)) { - if (erle_[0][k] > erle_onsets_[k]) { - erle_[0][k] = std::max(erle_onsets_[k], 0.97f * erle_[0][k]); - RTC_DCHECK_LE(min_erle_, erle_[0][k]); - } - if (hold_counters_[k] <= 0) { - coming_onset_[k] = true; - hold_counters_[k] = 0; + const int num_capture_channels = static_cast(accum_spectra_.Y2.size()); + for (int ch = 0; ch < num_capture_channels; ++ch) { + for (size_t k = 1; k < kFftLengthBy2; ++k) { + --hold_counters_[ch][k]; + if (hold_counters_[ch][k] <= + (kBlocksForOnsetDetection - kBlocksToHoldErle)) { + if (erle_[ch][k] > erle_onsets_[ch][k]) { + erle_[ch][k] = std::max(erle_onsets_[ch][k], 0.97f * erle_[ch][k]); + RTC_DCHECK_LE(min_erle_, erle_[ch][k]); + } + if (hold_counters_[ch][k] <= 0) { + coming_onset_[ch][k] = true; + hold_counters_[ch][k] = 0; + } } } } } void SubbandErleEstimator::ResetAccumulatedSpectra() { - accum_spectra_.Y2_.fill(0.f); - accum_spectra_.E2_.fill(0.f); - accum_spectra_.num_points_.fill(0); - accum_spectra_.low_render_energy_.fill(false); + for (size_t ch = 0; ch < erle_onsets_.size(); ++ch) { + accum_spectra_.Y2[ch].fill(0.f); + accum_spectra_.E2[ch].fill(0.f); + accum_spectra_.num_points[ch] = 0; + accum_spectra_.low_render_energy[ch].fill(false); + } } void SubbandErleEstimator::UpdateAccumulatedSpectra( - rtc::ArrayView X2, - rtc::ArrayView Y2, - rtc::ArrayView E2) { + rtc::ArrayView X2, + rtc::ArrayView> Y2, + rtc::ArrayView> E2, + const std::vector& converged_filters) { auto& st = accum_spectra_; - if (st.num_points_[0] == kPointsToAccumulate) { - 
st.num_points_[0] = 0; - st.Y2_.fill(0.f); - st.E2_.fill(0.f); - st.low_render_energy_.fill(false); - } - std::transform(Y2.begin(), Y2.end(), st.Y2_.begin(), st.Y2_.begin(), - std::plus()); - std::transform(E2.begin(), E2.end(), st.E2_.begin(), st.E2_.begin(), - std::plus()); + RTC_DCHECK_EQ(st.E2.size(), E2.size()); + RTC_DCHECK_EQ(st.E2.size(), E2.size()); + const int num_capture_channels = static_cast(Y2.size()); + for (int ch = 0; ch < num_capture_channels; ++ch) { + // Note that the use of the converged_filter flag already imposed + // a minimum of the erle that can be estimated as that flag would + // be false if the filter is performing poorly. + if (!converged_filters[ch]) { + continue; + } - for (size_t k = 0; k < X2.size(); ++k) { - st.low_render_energy_[k] = - st.low_render_energy_[k] || X2[k] < kX2BandEnergyThreshold; + if (st.num_points[ch] == kPointsToAccumulate) { + st.num_points[ch] = 0; + st.Y2[ch].fill(0.f); + st.E2[ch].fill(0.f); + st.low_render_energy[ch].fill(false); + } + + std::transform(Y2[ch].begin(), Y2[ch].end(), st.Y2[ch].begin(), + st.Y2[ch].begin(), std::plus()); + std::transform(E2[ch].begin(), E2[ch].end(), st.E2[ch].begin(), + st.E2[ch].begin(), std::plus()); + + for (size_t k = 0; k < X2.size(); ++k) { + st.low_render_energy[ch][k] = + st.low_render_energy[ch][k] || X2[k] < kX2BandEnergyThreshold; + } + + ++st.num_points[ch]; } - st.num_points_[0]++; - st.num_points_.fill(st.num_points_[0]); } } // namespace webrtc diff --git a/modules/audio_processing/aec3/subband_erle_estimator.h b/modules/audio_processing/aec3/subband_erle_estimator.h index 18bab7d138..90363e081d 100644 --- a/modules/audio_processing/aec3/subband_erle_estimator.h +++ b/modules/audio_processing/aec3/subband_erle_estimator.h @@ -35,47 +35,57 @@ class SubbandErleEstimator { void Reset(); // Updates the ERLE estimate. 
- void Update(rtc::ArrayView X2, - rtc::ArrayView Y2, - rtc::ArrayView E2, - bool converged_filter, - bool onset_detection); + void Update(rtc::ArrayView X2, + rtc::ArrayView> Y2, + rtc::ArrayView> E2, + const std::vector& converged_filters); // Returns the ERLE estimate. rtc::ArrayView> Erle() const { return erle_; } - // Returns the ERLE estimate at onsets. - rtc::ArrayView ErleOnsets() const { return erle_onsets_; } + // Returns the ERLE estimate at onsets (only used for testing). + rtc::ArrayView> ErleOnsets() + const { + return erle_onsets_; + } void Dump(const std::unique_ptr& data_dumper) const; private: struct AccumulatedSpectra { - std::array Y2_; - std::array E2_; - std::array low_render_energy_; - std::array num_points_; + explicit AccumulatedSpectra(size_t num_capture_channels) + : Y2(num_capture_channels), + E2(num_capture_channels), + low_render_energy(num_capture_channels), + num_points(num_capture_channels) {} + std::vector> Y2; + std::vector> E2; + std::vector> low_render_energy; + std::vector num_points; }; - void UpdateAccumulatedSpectra(rtc::ArrayView X2, - rtc::ArrayView Y2, - rtc::ArrayView E2); + void UpdateAccumulatedSpectra( + rtc::ArrayView X2, + rtc::ArrayView> Y2, + rtc::ArrayView> E2, + const std::vector& converged_filters); void ResetAccumulatedSpectra(); - void UpdateBands(bool onset_detection); + void UpdateBands(const std::vector& converged_filters); void DecreaseErlePerBandForLowRenderSignals(); + const bool use_onset_detection_; const float min_erle_; const std::array max_erle_; const bool use_min_erle_during_onsets_; AccumulatedSpectra accum_spectra_; std::vector> erle_; - std::array erle_onsets_; - std::array coming_onset_; - std::array hold_counters_; + std::vector> erle_onsets_; + std::vector> coming_onset_; + std::vector> hold_counters_; }; } // namespace webrtc diff --git a/modules/audio_processing/aec3/subtractor_output_analyzer.cc b/modules/audio_processing/aec3/subtractor_output_analyzer.cc index 9a0e0bbf7b..cf16001153 
100644 --- a/modules/audio_processing/aec3/subtractor_output_analyzer.cc +++ b/modules/audio_processing/aec3/subtractor_output_analyzer.cc @@ -16,26 +16,41 @@ namespace webrtc { -SubtractorOutputAnalyzer::SubtractorOutputAnalyzer() {} +SubtractorOutputAnalyzer::SubtractorOutputAnalyzer(size_t num_capture_channels) + : filters_converged_(num_capture_channels, false) {} void SubtractorOutputAnalyzer::Update( - const SubtractorOutput& subtractor_output) { - const float y2 = subtractor_output.y2; - const float e2_main = subtractor_output.e2_main; - const float e2_shadow = subtractor_output.e2_shadow; + rtc::ArrayView subtractor_output, + bool* any_filter_converged, + bool* all_filters_diverged) { + RTC_DCHECK(any_filter_converged); + RTC_DCHECK(all_filters_diverged); + RTC_DCHECK_EQ(subtractor_output.size(), filters_converged_.size()); - constexpr float kConvergenceThreshold = 50 * 50 * kBlockSize; - main_filter_converged_ = e2_main < 0.5f * y2 && y2 > kConvergenceThreshold; - shadow_filter_converged_ = - e2_shadow < 0.05f * y2 && y2 > kConvergenceThreshold; - float min_e2 = std::min(e2_main, e2_shadow); - filter_diverged_ = min_e2 > 1.5f * y2 && y2 > 30.f * 30.f * kBlockSize; + *any_filter_converged = false; + *all_filters_diverged = true; + + for (size_t ch = 0; ch < subtractor_output.size(); ++ch) { + const float y2 = subtractor_output[ch].y2; + const float e2_main = subtractor_output[ch].e2_main; + const float e2_shadow = subtractor_output[ch].e2_shadow; + + constexpr float kConvergenceThreshold = 50 * 50 * kBlockSize; + bool main_filter_converged = + e2_main < 0.5f * y2 && y2 > kConvergenceThreshold; + bool shadow_filter_converged = + e2_shadow < 0.05f * y2 && y2 > kConvergenceThreshold; + float min_e2 = std::min(e2_main, e2_shadow); + bool filter_diverged = min_e2 > 1.5f * y2 && y2 > 30.f * 30.f * kBlockSize; + filters_converged_[ch] = main_filter_converged || shadow_filter_converged; + + *any_filter_converged = *any_filter_converged || filters_converged_[ch]; + 
*all_filters_diverged = *all_filters_diverged && filter_diverged; + } } void SubtractorOutputAnalyzer::HandleEchoPathChange() { - shadow_filter_converged_ = false; - main_filter_converged_ = false; - filter_diverged_ = false; + std::fill(filters_converged_.begin(), filters_converged_.end(), false); } } // namespace webrtc diff --git a/modules/audio_processing/aec3/subtractor_output_analyzer.h b/modules/audio_processing/aec3/subtractor_output_analyzer.h index 76a25604d3..5328ae7f1e 100644 --- a/modules/audio_processing/aec3/subtractor_output_analyzer.h +++ b/modules/audio_processing/aec3/subtractor_output_analyzer.h @@ -11,32 +11,32 @@ #ifndef MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_OUTPUT_ANALYZER_H_ #define MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_OUTPUT_ANALYZER_H_ +#include + #include "modules/audio_processing/aec3/subtractor_output.h" namespace webrtc { -// Class for analyzing the properties subtractor output +// Class for analyzing the properties of the subtractor output. class SubtractorOutputAnalyzer { public: - SubtractorOutputAnalyzer(); + explicit SubtractorOutputAnalyzer(size_t num_capture_channels); ~SubtractorOutputAnalyzer() = default; // Analyses the subtractor output. - void Update(const SubtractorOutput& subtractor_output); + void Update(rtc::ArrayView subtractor_output, + bool* any_filter_converged, + bool* all_filters_diverged); - bool ConvergedFilter() const { - return main_filter_converged_ || shadow_filter_converged_; + const std::vector& ConvergedFilters() const { + return filters_converged_; } - bool DivergedFilter() const { return filter_diverged_; } - // Handle echo path change. void HandleEchoPathChange(); private: - bool shadow_filter_converged_ = false; - bool main_filter_converged_ = false; - bool filter_diverged_ = false; + std::vector filters_converged_; }; } // namespace webrtc