diff --git a/modules/audio_processing/aec3/BUILD.gn b/modules/audio_processing/aec3/BUILD.gn index d4980d7e00..3cc2ef5b90 100644 --- a/modules/audio_processing/aec3/BUILD.gn +++ b/modules/audio_processing/aec3/BUILD.gn @@ -97,6 +97,8 @@ rtc_static_library("aec3") { "subtractor.cc", "subtractor.h", "subtractor_output.h", + "subtractor_output_analyzer.cc", + "subtractor_output_analyzer.h", "suppression_filter.cc", "suppression_filter.h", "suppression_gain.cc", diff --git a/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc b/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc index 15717d1796..256371a4e1 100644 --- a/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc +++ b/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc @@ -328,7 +328,7 @@ TEST(AdaptiveFirFilter, FilterAndAdapt) { absl::optional delay_estimate; std::vector e(kBlockSize, 0.f); std::array s_scratch; - std::array s; + SubtractorOutput output; FftData S; FftData G; FftData E; @@ -342,6 +342,7 @@ TEST(AdaptiveFirFilter, FilterAndAdapt) { Y2.fill(0.f); E2_main.fill(0.f); E2_shadow.fill(0.f); + output.Reset(); constexpr float kScale = 1.0f / kFftLengthBy2; @@ -382,7 +383,7 @@ TEST(AdaptiveFirFilter, FilterAndAdapt) { [](float& a) { a = rtc::SafeClamp(a, -32768.f, 32767.f); }); fft.ZeroPaddedFft(e, Aec3Fft::Window::kRectangular, &E); for (size_t k = 0; k < kBlockSize; ++k) { - s[k] = kScale * s_scratch[k + kFftLengthBy2]; + output.s_main[k] = kScale * s_scratch[k + kFftLengthBy2]; } std::array render_power; @@ -394,8 +395,8 @@ TEST(AdaptiveFirFilter, FilterAndAdapt) { false, EchoPathVariability::DelayAdjustment::kNone, false)); aec_state.Update(delay_estimate, filter.FilterFrequencyResponse(), - filter.FilterImpulseResponse(), true, false, - *render_buffer, E2_main, Y2, s); + filter.FilterImpulseResponse(), *render_buffer, E2_main, + Y2, output, y); } // Verify that the filter is able to perform well. EXPECT_LT(1000 * std::inner_product(e.begin(), e.end(), e.begin(), 0.f), diff --git a/modules/audio_processing/aec3/aec_state.cc b/modules/audio_processing/aec3/aec_state.cc index b03a121555..f6ce23bdef 100644 --- a/modules/audio_processing/aec3/aec_state.cc +++ b/modules/audio_processing/aec3/aec_state.cc @@ -122,6 +122,8 @@ void AecState::HandleEchoPathChange( } else if (echo_path_variability.gain_change) { blocks_since_reset_ = kNumBlocksPerSecond; } + + subtractor_output_analyzer_.HandleEchoPathChange(); } void AecState::Update( @@ -129,12 +131,17 @@ void AecState::Update( const std::vector>& adaptive_filter_frequency_response, const std::vector& adaptive_filter_impulse_response, - bool converged_filter, - bool diverged_filter, const RenderBuffer& render_buffer, const std::array& E2_main, const std::array& Y2, - const std::array& s) { + const SubtractorOutput& subtractor_output, + rtc::ArrayView y) { + // Analyze the filter output. + subtractor_output_analyzer_.Update(y, subtractor_output); + + const bool converged_filter = subtractor_output_analyzer_.ConvergedFilter(); + const bool diverged_filter = subtractor_output_analyzer_.DivergedFilter(); + // Analyze the filter and compute the delays. filter_analyzer_.Update(adaptive_filter_impulse_response, adaptive_filter_frequency_response, render_buffer); diff --git a/modules/audio_processing/aec3/aec_state.h b/modules/audio_processing/aec3/aec_state.h index caccdf7412..091cd1fd9e 100644 --- a/modules/audio_processing/aec3/aec_state.h +++ b/modules/audio_processing/aec3/aec_state.h @@ -29,6 +29,8 @@ #include "modules/audio_processing/aec3/filter_analyzer.h" #include "modules/audio_processing/aec3/render_buffer.h" #include "modules/audio_processing/aec3/reverb_model_estimator.h" +#include "modules/audio_processing/aec3/subtractor_output.h" +#include "modules/audio_processing/aec3/subtractor_output_analyzer.h" #include "modules/audio_processing/aec3/suppression_gain_limiter.h" #include "rtc_base/constructormagic.h" @@ -151,12 +153,11 @@ class AecState { const std::vector>& adaptive_filter_frequency_response, const std::vector& adaptive_filter_impulse_response, - bool converged_filter, - bool diverged_filter, const RenderBuffer& render_buffer, const std::array& E2_main, const std::array& Y2, - const std::array& s); + const SubtractorOutput& subtractor_output, + rtc::ArrayView y); // Returns the tail freq. response of the linear filter. rtc::ArrayView GetFreqRespTail() const { @@ -218,6 +219,7 @@ class AecState { size_t active_blocks_since_converged_filter_ = 0; EchoAudibility echo_audibility_; ReverbModelEstimator reverb_model_estimator_; + SubtractorOutputAnalyzer subtractor_output_analyzer_; RTC_DISALLOW_COPY_AND_ASSIGN(AecState); }; diff --git a/modules/audio_processing/aec3/aec_state_unittest.cc b/modules/audio_processing/aec3/aec_state_unittest.cc index 00d25d61df..6111979f82 100644 --- a/modules/audio_processing/aec3/aec_state_unittest.cc +++ b/modules/audio_processing/aec3/aec_state_unittest.cc @@ -31,9 +31,12 @@ TEST(AecState, NormalUsage) { std::vector> x(3, std::vector(kBlockSize, 0.f)); EchoPathVariability echo_path_variability( false, EchoPathVariability::DelayAdjustment::kNone, false); - std::array s; + SubtractorOutput output; + std::array y; Aec3Fft fft; - s.fill(100.f); + output.s_main.fill(100.f); + output.e_main.fill(100.f); + y.fill(1000.f); std::vector> converged_filter_frequency_response(10); @@ -50,8 +53,8 @@ TEST(AecState, NormalUsage) { // Verify that linear AEC usability is false when the filter is diverged. state.Update(delay_estimate, diverged_filter_frequency_response, - impulse_response, true, false, - *render_delay_buffer->GetRenderBuffer(), E2_main, Y2, s); + impulse_response, *render_delay_buffer->GetRenderBuffer(), + E2_main, Y2, output, y); EXPECT_FALSE(state.UsableLinearEstimate()); // Verify that linear AEC usability is true when the filter is converged @@ -59,8 +62,8 @@ TEST(AecState, NormalUsage) { for (int k = 0; k < 3000; ++k) { render_delay_buffer->Insert(x); state.Update(delay_estimate, converged_filter_frequency_response, - impulse_response, true, false, - *render_delay_buffer->GetRenderBuffer(), E2_main, Y2, s); + impulse_response, *render_delay_buffer->GetRenderBuffer(), + E2_main, Y2, output, y); } EXPECT_TRUE(state.UsableLinearEstimate()); @@ -69,8 +72,8 @@ TEST(AecState, NormalUsage) { state.HandleEchoPathChange(EchoPathVariability( true, EchoPathVariability::DelayAdjustment::kNone, false)); state.Update(delay_estimate, converged_filter_frequency_response, - impulse_response, true, false, - *render_delay_buffer->GetRenderBuffer(), E2_main, Y2, s); + impulse_response, *render_delay_buffer->GetRenderBuffer(), + E2_main, Y2, output, y); EXPECT_FALSE(state.UsableLinearEstimate()); // Verify that the active render detection works as intended. @@ -79,15 +82,15 @@ TEST(AecState, NormalUsage) { state.HandleEchoPathChange(EchoPathVariability( true, EchoPathVariability::DelayAdjustment::kNewDetectedDelay, false)); state.Update(delay_estimate, converged_filter_frequency_response, - impulse_response, true, false, - *render_delay_buffer->GetRenderBuffer(), E2_main, Y2, s); + impulse_response, *render_delay_buffer->GetRenderBuffer(), + E2_main, Y2, output, y); EXPECT_FALSE(state.ActiveRender()); for (int k = 0; k < 1000; ++k) { render_delay_buffer->Insert(x); state.Update(delay_estimate, converged_filter_frequency_response, - impulse_response, true, false, - *render_delay_buffer->GetRenderBuffer(), E2_main, Y2, s); + impulse_response, *render_delay_buffer->GetRenderBuffer(), + E2_main, Y2, output, y); } EXPECT_TRUE(state.ActiveRender()); @@ -109,8 +112,8 @@ TEST(AecState, NormalUsage) { Y2.fill(10.f * 10000.f * 10000.f); for (size_t k = 0; k < 1000; ++k) { state.Update(delay_estimate, converged_filter_frequency_response, - impulse_response, true, false, - *render_delay_buffer->GetRenderBuffer(), E2_main, Y2, s); + impulse_response, *render_delay_buffer->GetRenderBuffer(), + E2_main, Y2, output, y); } ASSERT_TRUE(state.UsableLinearEstimate()); @@ -126,8 +129,8 @@ TEST(AecState, NormalUsage) { Y2.fill(10.f * E2_main[0]); for (size_t k = 0; k < 1000; ++k) { state.Update(delay_estimate, converged_filter_frequency_response, - impulse_response, true, false, - *render_delay_buffer->GetRenderBuffer(), E2_main, Y2, s); + impulse_response, *render_delay_buffer->GetRenderBuffer(), + E2_main, Y2, output, y); } ASSERT_TRUE(state.UsableLinearEstimate()); { @@ -147,8 +150,8 @@ TEST(AecState, NormalUsage) { Y2.fill(5.f * E2_main[0]); for (size_t k = 0; k < 1000; ++k) { state.Update(delay_estimate, converged_filter_frequency_response, - impulse_response, true, false, - *render_delay_buffer->GetRenderBuffer(), E2_main, Y2, s); + impulse_response, *render_delay_buffer->GetRenderBuffer(), + E2_main, Y2, output, y); } ASSERT_TRUE(state.UsableLinearEstimate()); @@ -179,9 +182,11 @@ TEST(AecState, ConvergedFilterDelay) { std::array x; EchoPathVariability echo_path_variability( false, EchoPathVariability::DelayAdjustment::kNone, false); - std::array s; - s.fill(100.f); + SubtractorOutput output; + std::array y; + output.s_main.fill(100.f); x.fill(0.f); + y.fill(0.f); std::vector> frequency_response( kFilterLengthBlocks); @@ -198,9 +203,9 @@ TEST(AecState, ConvergedFilterDelay) { impulse_response[k * kBlockSize + 1] = 1.f; state.HandleEchoPathChange(echo_path_variability); - state.Update(delay_estimate, frequency_response, impulse_response, true, - false, *render_delay_buffer->GetRenderBuffer(), E2_main, Y2, - s); + state.Update(delay_estimate, frequency_response, impulse_response, + *render_delay_buffer->GetRenderBuffer(), E2_main, Y2, output, + y); } } diff --git a/modules/audio_processing/aec3/echo_remover.cc b/modules/audio_processing/aec3/echo_remover.cc index d382d93704..2e692980fb 100644 --- a/modules/audio_processing/aec3/echo_remover.cc +++ b/modules/audio_processing/aec3/echo_remover.cc @@ -210,9 +210,8 @@ void EchoRemoverImpl::ProcessCapture( // Update the AEC state information. aec_state_.Update(external_delay, subtractor_.FilterFrequencyResponse(), - subtractor_.FilterImpulseResponse(), - subtractor_.ConvergedFilter(), subtractor_.DivergedFilter(), - *render_buffer, E2, Y2, subtractor_output.s_main); + subtractor_.FilterImpulseResponse(), *render_buffer, E2, Y2, + subtractor_output, y0); // Compute spectra. const bool suppression_gain_uses_ffts = diff --git a/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc b/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc index e9d768b158..bed148a57d 100644 --- a/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc +++ b/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc @@ -162,8 +162,9 @@ void RunFilterUpdateTest(int num_blocks_to_process, aec_state.HandleEchoPathChange(EchoPathVariability( false, EchoPathVariability::DelayAdjustment::kNone, false)); aec_state.Update(delay_estimate, main_filter.FilterFrequencyResponse(), - main_filter.FilterImpulseResponse(), true, false, - *render_delay_buffer->GetRenderBuffer(), E2_main, Y2, s); + main_filter.FilterImpulseResponse(), + *render_delay_buffer->GetRenderBuffer(), E2_main, Y2, + output, y); } std::copy(e_main.begin(), e_main.end(), e_last_block->begin()); diff --git a/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc b/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc index 30f3ddcf96..832d8ca1d3 100644 --- a/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc +++ b/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc @@ -61,7 +61,8 @@ TEST(ResidualEchoEstimator, DISABLED_BasicTest) { std::vector> x(3, std::vector(kBlockSize, 0.f)); std::vector> H2(10); Random random_generator(42U); - std::array s; + SubtractorOutput output; + std::array y; Aec3Fft fft; absl::optional delay_estimate; @@ -74,7 +75,9 @@ TEST(ResidualEchoEstimator, DISABLED_BasicTest) { std::vector h(GetTimeDomainLength(config.filter.main.length_blocks), 0.f); - s.fill(100.f); + output.Reset(); + output.s_main.fill(100.f); + y.fill(0.f); constexpr float kLevel = 10.f; E2_shadow.fill(kLevel); @@ -93,8 +96,9 @@ TEST(ResidualEchoEstimator, DISABLED_BasicTest) { render_delay_buffer->PrepareCaptureProcessing(); aec_state.HandleEchoPathChange(echo_path_variability); - aec_state.Update(delay_estimate, H2, h, true, false, - *render_delay_buffer->GetRenderBuffer(), E2_main, Y2, s); + aec_state.Update(delay_estimate, H2, h, + *render_delay_buffer->GetRenderBuffer(), E2_main, Y2, + output, y); estimator.Estimate(aec_state, *render_delay_buffer->GetRenderBuffer(), S2_linear, Y2, &R2); diff --git a/modules/audio_processing/aec3/subtractor.cc b/modules/audio_processing/aec3/subtractor.cc index b1eba18fbc..a9a3e6faa3 100644 --- a/modules/audio_processing/aec3/subtractor.cc +++ b/modules/audio_processing/aec3/subtractor.cc @@ -111,11 +111,8 @@ void Subtractor::HandleEchoPathChange( G_shadow_.HandleEchoPathChange(); G_main_.SetConfig(config_.filter.main_initial, true); G_shadow_.SetConfig(config_.filter.shadow_initial, true); - main_filter_converged_ = false; - shadow_filter_converged_ = false; main_filter_.SetSizePartitions(config_.filter.main_initial.length_blocks, true); - main_filter_once_converged_ = false; shadow_filter_.SetSizePartitions( config_.filter.shadow_initial.length_blocks, true); }; @@ -171,24 +168,8 @@ void Subtractor::Process(const RenderBuffer& render_buffer, &shadow_saturation); fft_.ZeroPaddedFft(e_shadow, Aec3Fft::Window::kHanning, &E_shadow); - // Check for filter convergence. - const auto sum_of_squares = [](float a, float b) { return a + b * b; }; - const float y2 = std::accumulate(y.begin(), y.end(), 0.f, sum_of_squares); - const float e2_main = - std::accumulate(e_main.begin(), e_main.end(), 0.f, sum_of_squares); - const float e2_shadow = - std::accumulate(e_shadow.begin(), e_shadow.end(), 0.f, sum_of_squares); - - constexpr float kConvergenceThreshold = 50 * 50 * kBlockSize; - main_filter_converged_ = e2_main < 0.5f * y2 && y2 > kConvergenceThreshold; - shadow_filter_converged_ = - e2_shadow < 0.05 * y2 && y2 > kConvergenceThreshold; - main_filter_once_converged_ = - main_filter_once_converged_ || main_filter_converged_; - main_filter_diverged_ = e2_main > 1.5f * y2 && y2 > 30.f * 30.f * kBlockSize; - if (enable_misadjustment_estimator_) { - filter_misadjustment_estimator_.Update(e2_main, y2); + filter_misadjustment_estimator_.Update(e_main, y); if (filter_misadjustment_estimator_.IsAdjustmentNeeded()) { float scale = filter_misadjustment_estimator_.GetMisadjustment(); main_filter_.ScaleFilter(scale); @@ -229,7 +210,13 @@ void Subtractor::Process(const RenderBuffer& render_buffer, } } -void Subtractor::FilterMisadjustmentEstimator::Update(float e2, float y2) { +void Subtractor::FilterMisadjustmentEstimator::Update( + rtc::ArrayView e, + rtc::ArrayView y) { + const auto sum_of_squares = [](float a, float b) { return a + b * b; }; + const float y2 = std::accumulate(y.begin(), y.end(), 0.f, sum_of_squares); + const float e2 = std::accumulate(e.begin(), e.end(), 0.f, sum_of_squares); + e2_acum_ += e2; y2_acum_ += y2; if (++n_blocks_acum_ == n_blocks_) { diff --git a/modules/audio_processing/aec3/subtractor.h b/modules/audio_processing/aec3/subtractor.h index 8ecda2db46..67ecdc02c6 100644 --- a/modules/audio_processing/aec3/subtractor.h +++ b/modules/audio_processing/aec3/subtractor.h @@ -54,24 +54,14 @@ class Subtractor { // Returns the block-wise frequency response for the main adaptive filter. const std::vector>& FilterFrequencyResponse() const { - return main_filter_once_converged_ || (!shadow_filter_converged_) - ? main_filter_.FilterFrequencyResponse() - : shadow_filter_.FilterFrequencyResponse(); + return main_filter_.FilterFrequencyResponse(); } // Returns the estimate of the impulse response for the main adaptive filter. const std::vector& FilterImpulseResponse() const { - return main_filter_once_converged_ || (!shadow_filter_converged_) - ? main_filter_.FilterImpulseResponse() - : shadow_filter_.FilterImpulseResponse(); + return main_filter_.FilterImpulseResponse(); } - bool ConvergedFilter() const { - return main_filter_converged_ || shadow_filter_converged_; - } - - bool DivergedFilter() const { return main_filter_diverged_; } - void DumpFilters() { main_filter_.DumpFilter("aec3_subtractor_H_main", "aec3_subtractor_h_main"); shadow_filter_.DumpFilter("aec3_subtractor_H_shadow", @@ -84,7 +74,7 @@ class Subtractor { FilterMisadjustmentEstimator() = default; ~FilterMisadjustmentEstimator() = default; // Update the misadjustment estimator. - void Update(float e2, float y2); + void Update(rtc::ArrayView e, rtc::ArrayView y); // GetMisadjustment() Returns a recommended scale for the filter so the // prediction error energy gets closer to the energy that is seen at the // microphone input. @@ -120,10 +110,6 @@ class Subtractor { AdaptiveFirFilter shadow_filter_; MainFilterUpdateGain G_main_; ShadowFilterUpdateGain G_shadow_; - bool main_filter_converged_ = false; - bool main_filter_once_converged_ = false; - bool shadow_filter_converged_ = false; - bool main_filter_diverged_ = false; FilterMisadjustmentEstimator filter_misadjustment_estimator_; RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(Subtractor); }; diff --git a/modules/audio_processing/aec3/subtractor_output_analyzer.cc b/modules/audio_processing/aec3/subtractor_output_analyzer.cc new file mode 100644 index 0000000000..3446c5a54e --- /dev/null +++ b/modules/audio_processing/aec3/subtractor_output_analyzer.cc @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/subtractor_output_analyzer.h" + +#include +#include + +namespace webrtc { + +void SubtractorOutputAnalyzer::Update( + rtc::ArrayView y, + const SubtractorOutput& subtractor_output) { + const auto& e_main = subtractor_output.e_main; + const auto& e_shadow = subtractor_output.e_shadow; + + const auto sum_of_squares = [](float a, float b) { return a + b * b; }; + const float y2 = std::accumulate(y.begin(), y.end(), 0.f, sum_of_squares); + const float e2_main = + std::accumulate(e_main.begin(), e_main.end(), 0.f, sum_of_squares); + const float e2_shadow = + std::accumulate(e_shadow.begin(), e_shadow.end(), 0.f, sum_of_squares); + + constexpr float kConvergenceThreshold = 50 * 50 * kBlockSize; + main_filter_converged_ = e2_main < 0.5f * y2 && y2 > kConvergenceThreshold; + shadow_filter_converged_ = + e2_shadow < 0.05 * y2 && y2 > kConvergenceThreshold; + main_filter_diverged_ = e2_main > 1.5f * y2 && y2 > 30.f * 30.f * kBlockSize; +} + +void SubtractorOutputAnalyzer::HandleEchoPathChange() { + shadow_filter_converged_ = false; + main_filter_converged_ = false; + main_filter_diverged_ = false; +} + +} // namespace webrtc diff --git a/modules/audio_processing/aec3/subtractor_output_analyzer.h b/modules/audio_processing/aec3/subtractor_output_analyzer.h new file mode 100644 index 0000000000..f1765cd061 --- /dev/null +++ b/modules/audio_processing/aec3/subtractor_output_analyzer.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_OUTPUT_ANALYZER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_OUTPUT_ANALYZER_H_ + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/subtractor_output.h" + +namespace webrtc { + +// Class for analyzing the properties subtractor output +class SubtractorOutputAnalyzer { + public: + SubtractorOutputAnalyzer() = default; + ~SubtractorOutputAnalyzer() = default; + + // Analyses the subtractor output. + void Update(rtc::ArrayView y, + const SubtractorOutput& subtractor_output); + + bool ConvergedFilter() const { + return main_filter_converged_ || shadow_filter_converged_; + } + + bool DivergedFilter() const { return main_filter_diverged_; } + + // Handle echo path change. + void HandleEchoPathChange(); + + private: + bool shadow_filter_converged_ = false; + bool main_filter_converged_ = false; + bool main_filter_diverged_ = false; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_OUTPUT_ANALYZER_H_ diff --git a/modules/audio_processing/aec3/subtractor_unittest.cc b/modules/audio_processing/aec3/subtractor_unittest.cc index 3c896ee0ca..35abf1cd2a 100644 --- a/modules/audio_processing/aec3/subtractor_unittest.cc +++ b/modules/audio_processing/aec3/subtractor_unittest.cc @@ -85,9 +85,8 @@ float RunSubtractorTest(int num_blocks_to_process, false, EchoPathVariability::DelayAdjustment::kNone, false)); aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse(), subtractor.FilterImpulseResponse(), - subtractor.ConvergedFilter(), subtractor.DivergedFilter(), *render_delay_buffer->GetRenderBuffer(), E2_main, Y2, - output.s_main); + output, y); } const float output_power = std::inner_product( diff --git a/modules/audio_processing/aec3/suppression_gain_unittest.cc b/modules/audio_processing/aec3/suppression_gain_unittest.cc index 4670178ee2..2eeb68b15e 100644 --- a/modules/audio_processing/aec3/suppression_gain_unittest.cc +++ b/modules/audio_processing/aec3/suppression_gain_unittest.cc @@ -13,6 +13,7 @@ #include "modules/audio_processing/aec3/aec_state.h" #include "modules/audio_processing/aec3/render_delay_buffer.h" #include "modules/audio_processing/aec3/subtractor.h" +#include "modules/audio_processing/aec3/subtractor_output.h" #include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/checks.h" #include "system_wrappers/include/cpu_features_wrapper.h" @@ -67,7 +68,8 @@ TEST(SuppressionGain, BasicGainComputation) { std::array R2; std::array N2; std::array g; - std::array s; + SubtractorOutput output; + std::array y; FftData E; FftData X; FftData Y; @@ -85,7 +87,8 @@ TEST(SuppressionGain, BasicGainComputation) { Y2.fill(10.f); R2.fill(0.1f); N2.fill(100.f); - s.fill(10.f); + output.Reset(); + y.fill(0.f); E.re.fill(sqrtf(E2[0])); E.im.fill(0.f); X.re.fill(sqrtf(R2[0])); @@ -97,15 +100,15 @@ TEST(SuppressionGain, BasicGainComputation) { for (int k = 0; k <= kNumBlocksPerSecond / 5 + 1; ++k) { aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse(), subtractor.FilterImpulseResponse(), - subtractor.ConvergedFilter(), subtractor.DivergedFilter(), - *render_delay_buffer->GetRenderBuffer(), E2, Y2, s); + *render_delay_buffer->GetRenderBuffer(), E2, Y2, output, + y); } for (int k = 0; k < 100; ++k) { aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse(), subtractor.FilterImpulseResponse(), - subtractor.ConvergedFilter(), subtractor.DivergedFilter(), - *render_delay_buffer->GetRenderBuffer(), E2, Y2, s); + *render_delay_buffer->GetRenderBuffer(), E2, Y2, output, + y); suppression_gain.GetGain(E2, R2, N2, E, X, Y, analyzer, aec_state, x, &high_bands_gain, &g); } @@ -124,8 +127,8 @@ TEST(SuppressionGain, BasicGainComputation) { for (int k = 0; k < 100; ++k) { aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse(), subtractor.FilterImpulseResponse(), - subtractor.ConvergedFilter(), subtractor.DivergedFilter(), - *render_delay_buffer->GetRenderBuffer(), E2, Y2, s); + *render_delay_buffer->GetRenderBuffer(), E2, Y2, output, + y); suppression_gain.GetGain(E2, R2, N2, E, X, Y, analyzer, aec_state, x, &high_bands_gain, &g); }