diff --git a/modules/audio_processing/aec3/render_buffer.cc b/modules/audio_processing/aec3/render_buffer.cc index 6e224be703..235e3e335a 100644 --- a/modules/audio_processing/aec3/render_buffer.cc +++ b/modules/audio_processing/aec3/render_buffer.cc @@ -47,4 +47,28 @@ void RenderBuffer::SpectralSum( } } +void RenderBuffer::SpectralSums( + size_t num_spectra_shorter, + size_t num_spectra_longer, + std::array* X2_shorter, + std::array* X2_longer) const { + RTC_DCHECK_LE(num_spectra_shorter, num_spectra_longer); + X2_shorter->fill(0.f); + int position = spectrum_buffer_->read; + size_t j = 0; + for (; j < num_spectra_shorter; ++j) { + std::transform(X2_shorter->begin(), X2_shorter->end(), + spectrum_buffer_->buffer[position].begin(), + X2_shorter->begin(), std::plus()); + position = spectrum_buffer_->IncIndex(position); + } + std::copy(X2_shorter->begin(), X2_shorter->end(), X2_longer->begin()); + for (; j < num_spectra_longer; ++j) { + std::transform(X2_longer->begin(), X2_longer->end(), + spectrum_buffer_->buffer[position].begin(), + X2_longer->begin(), std::plus()); + position = spectrum_buffer_->IncIndex(position); + } +} + } // namespace webrtc diff --git a/modules/audio_processing/aec3/render_buffer.h b/modules/audio_processing/aec3/render_buffer.h index 34e7edf37a..dd67268efd 100644 --- a/modules/audio_processing/aec3/render_buffer.h +++ b/modules/audio_processing/aec3/render_buffer.h @@ -61,6 +61,12 @@ class RenderBuffer { void SpectralSum(size_t num_spectra, std::array* X2) const; + // Returns the sums of the spectrums for two numbers of FFTs. + void SpectralSums(size_t num_spectra_shorter, + size_t num_spectra_longer, + std::array* X2_shorter, + std::array* X2_longer) const; + // Gets the recent activity seen in the render signal. bool GetRenderActivity() const { return render_activity_; } diff --git a/modules/audio_processing/aec3/subtractor.cc b/modules/audio_processing/aec3/subtractor.cc index 4477376f1a..b2fc993f48 100644 --- a/modules/audio_processing/aec3/subtractor.cc +++ b/modules/audio_processing/aec3/subtractor.cc @@ -128,12 +128,6 @@ Subtractor::Subtractor(const EchoCanceller3Config& config, G_shadow_(config_.filter.shadow_initial, config.filter.config_change_duration_blocks) { RTC_DCHECK(data_dumper_); - // Currently, the rest of AEC3 requires the main and shadow filter lengths to - // be identical. - RTC_DCHECK_EQ(config_.filter.main.length_blocks, - config_.filter.shadow.length_blocks); - RTC_DCHECK_EQ(config_.filter.main_initial.length_blocks, - config_.filter.shadow_initial.length_blocks); } Subtractor::~Subtractor() = default; @@ -222,11 +216,28 @@ void Subtractor::Process(const RenderBuffer& render_buffer, E_shadow.Spectrum(optimization_, output->E2_shadow); E_main.Spectrum(optimization_, output->E2_main); + // Compute the render powers. + std::array X2_main; + std::array X2_shadow_data; + std::array& X2_shadow = + main_filter_.SizePartitions() == shadow_filter_.SizePartitions() + ? X2_main + : X2_shadow_data; + if (main_filter_.SizePartitions() == shadow_filter_.SizePartitions()) { + render_buffer.SpectralSum(main_filter_.SizePartitions(), &X2_main); + } else if (main_filter_.SizePartitions() > shadow_filter_.SizePartitions()) { + render_buffer.SpectralSums(shadow_filter_.SizePartitions(), + main_filter_.SizePartitions(), &X2_shadow, + &X2_main); + } else { + render_buffer.SpectralSums(main_filter_.SizePartitions(), + shadow_filter_.SizePartitions(), &X2_main, + &X2_shadow); + } + // Update the main filter. - std::array X2; - render_buffer.SpectralSum(main_filter_.SizePartitions(), &X2); if (!main_filter_adjusted) { - G_main_.Compute(X2, render_signal_analyzer, *output, main_filter_, + G_main_.Compute(X2_main, render_signal_analyzer, *output, main_filter_, aec_state.SaturatedCapture() || main_saturation, &G); } else { G.re.fill(0.f); @@ -244,19 +255,15 @@ void Subtractor::Process(const RenderBuffer& render_buffer, (poor_shadow_filter_counter_ < 10 && !enable_early_shadow_filter_jumpstart_)) || !enable_shadow_filter_jumpstart_) { - if (shadow_filter_.SizePartitions() != main_filter_.SizePartitions()) { - render_buffer.SpectralSum(shadow_filter_.SizePartitions(), &X2); - } - G_shadow_.Compute(X2, render_signal_analyzer, E_shadow, + G_shadow_.Compute(X2_shadow, render_signal_analyzer, E_shadow, shadow_filter_.SizePartitions(), aec_state.SaturatedCapture() || shadow_saturation, &G); shadow_filter_.Adapt(render_buffer, G); } else { poor_shadow_filter_counter_ = 0; - if (enable_shadow_filter_boosted_jumpstart_) { shadow_filter_.SetFilter(main_filter_.GetFilter()); - G_shadow_.Compute(X2, render_signal_analyzer, E_main, + G_shadow_.Compute(X2_shadow, render_signal_analyzer, E_main, shadow_filter_.SizePartitions(), aec_state.SaturatedCapture() || main_saturation, &G); shadow_filter_.Adapt(render_buffer, G); diff --git a/modules/audio_processing/aec3/subtractor_unittest.cc b/modules/audio_processing/aec3/subtractor_unittest.cc index 35abf1cd2a..77918056b4 100644 --- a/modules/audio_processing/aec3/subtractor_unittest.cc +++ b/modules/audio_processing/aec3/subtractor_unittest.cc @@ -25,13 +25,15 @@ namespace { float RunSubtractorTest(int num_blocks_to_process, int delay_samples, - int filter_length_blocks, + int main_filter_length_blocks, + int shadow_filter_length_blocks, bool uncorrelated_inputs, const std::vector& blocks_with_echo_path_changes) { ApmDataDumper data_dumper(42); EchoCanceller3Config config; - config.filter.main.length_blocks = config.filter.shadow.length_blocks = - filter_length_blocks; + config.filter.main.length_blocks = main_filter_length_blocks; + config.filter.shadow.length_blocks = shadow_filter_length_blocks; + Subtractor subtractor(config, &data_dumper, DetectOptimization()); absl::optional delay_estimate; std::vector> x(3, std::vector(kBlockSize, 0.f)); @@ -160,9 +162,9 @@ TEST(Subtractor, Convergence) { for (size_t delay_samples : {0, 64, 150, 200, 301}) { SCOPED_TRACE(ProduceDebugText(delay_samples, filter_length_blocks)); - float echo_to_nearend_power = - RunSubtractorTest(400, delay_samples, filter_length_blocks, false, - blocks_with_echo_path_changes); + float echo_to_nearend_power = RunSubtractorTest( + 400, delay_samples, filter_length_blocks, filter_length_blocks, false, + blocks_with_echo_path_changes); // Use different criteria to take overmodelling into account. if (filter_length_blocks == 12) { @@ -174,6 +176,24 @@ TEST(Subtractor, Convergence) { } } +// Verifies that the subtractor is able to handle the case when the main filter +// is longer than the shadow filter. +TEST(Subtractor, MainFilterLongerThanShadowFilter) { + std::vector blocks_with_echo_path_changes; + float echo_to_nearend_power = + RunSubtractorTest(400, 64, 20, 15, false, blocks_with_echo_path_changes); + EXPECT_GT(0.5f, echo_to_nearend_power); +} + +// Verifies that the subtractor is able to handle the case when the shadow +// filter is longer than the main filter. +TEST(Subtractor, ShadowFilterLongerThanMainFilter) { + std::vector blocks_with_echo_path_changes; + float echo_to_nearend_power = + RunSubtractorTest(400, 64, 15, 20, false, blocks_with_echo_path_changes); + EXPECT_GT(0.5f, echo_to_nearend_power); +} + // Verifies that the subtractor does not converge on uncorrelated signals. TEST(Subtractor, NonConvergenceOnUncorrelatedSignals) { std::vector blocks_with_echo_path_changes; @@ -181,9 +201,9 @@ TEST(Subtractor, NonConvergenceOnUncorrelatedSignals) { for (size_t delay_samples : {0, 64, 150, 200, 301}) { SCOPED_TRACE(ProduceDebugText(delay_samples, filter_length_blocks)); - float echo_to_nearend_power = - RunSubtractorTest(300, delay_samples, filter_length_blocks, true, - blocks_with_echo_path_changes); + float echo_to_nearend_power = RunSubtractorTest( + 300, delay_samples, filter_length_blocks, filter_length_blocks, true, + blocks_with_echo_path_changes); EXPECT_NEAR(1.f, echo_to_nearend_power, 0.1); } } @@ -198,9 +218,9 @@ TEST(Subtractor, EchoPathChangeReset) { for (size_t delay_samples : {0, 64, 150, 200, 301}) { SCOPED_TRACE(ProduceDebugText(delay_samples, filter_length_blocks)); - float echo_to_nearend_power = - RunSubtractorTest(100, delay_samples, filter_length_blocks, false, - blocks_with_echo_path_changes); + float echo_to_nearend_power = RunSubtractorTest( + 100, delay_samples, filter_length_blocks, filter_length_blocks, false, + blocks_with_echo_path_changes); EXPECT_NEAR(1.f, echo_to_nearend_power, 0.0000001f); } }