From 6e5433c4d4f37133ca56bc9c16e07f39d8e29897 Mon Sep 17 00:00:00 2001 From: Sam Zackrisson Date: Fri, 18 Oct 2019 16:49:13 +0200 Subject: [PATCH] AEC3: Multi channel ERL estimator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The estimator will simply compute the worst value of all combinations of render and capture signal. This has the drawback that low-volume or silent render channels may severely misestimate the ERL. The changes have been shown to be bitexact over a large dataset. Bug: webrtc:10913 Change-Id: Id53c3ab81646ac0fab303edafc5e38892d285d8e Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/157308 Commit-Queue: Sam Zackrisson Reviewed-by: Per Ã…hgren Cr-Commit-Position: refs/heads/master@{#29542} --- modules/audio_processing/aec3/aec_state.cc | 8 +- .../aec3/aec_state_unittest.cc | 8 +- .../audio_processing/aec3/erl_estimator.cc | 61 +++++++++-- modules/audio_processing/aec3/erl_estimator.h | 9 +- .../aec3/erl_estimator_unittest.cc | 100 +++++++++++------- 5 files changed, 134 insertions(+), 52 deletions(-) diff --git a/modules/audio_processing/aec3/aec_state.cc b/modules/audio_processing/aec3/aec_state.cc index d35bed54ed..365ec9e5c7 100644 --- a/modules/audio_processing/aec3/aec_state.cc +++ b/modules/audio_processing/aec3/aec_state.cc @@ -230,11 +230,9 @@ void AecState::Update( avg_render_spectrum_with_reverb, Y2, E2_main, subtractor_output_analyzer_.ConvergedFilters()); - // TODO(bugs.webrtc.org/10913): Take all channels into account. - const auto& X2 = render_buffer.Spectrum( - delay_state_.MinDirectPathFilterDelay())[/*channel=*/0]; - erl_estimator_.Update(subtractor_output_analyzer_.ConvergedFilters()[0], X2, - Y2[0]); + erl_estimator_.Update( + subtractor_output_analyzer_.ConvergedFilters(), + render_buffer.Spectrum(delay_state_.MinDirectPathFilterDelay()), Y2); // Detect and flag echo saturation. saturation_detector_.Update(aligned_render_block, SaturatedCapture(), diff --git a/modules/audio_processing/aec3/aec_state_unittest.cc b/modules/audio_processing/aec3/aec_state_unittest.cc index b038770b11..c068b6e5f4 100644 --- a/modules/audio_processing/aec3/aec_state_unittest.cc +++ b/modules/audio_processing/aec3/aec_state_unittest.cc @@ -106,7 +106,9 @@ void RunNormalUsageTest(size_t num_render_channels, EXPECT_FALSE(state.UsableLinearEstimate()); // Verify that the active render detection works as intended. - std::fill(x[0][0].begin(), x[0][0].end(), 101.f); + for (size_t ch = 0; ch < num_render_channels; ++ch) { + std::fill(x[0][ch].begin(), x[0][ch].end(), 101.f); + } render_delay_buffer->Insert(x); for (size_t ch = 0; ch < num_capture_channels; ++ch) { subtractor_output[ch].ComputeMetrics(y[ch]); @@ -136,7 +138,9 @@ void RunNormalUsageTest(size_t num_render_channels, } } - x[0][0][0] = 5000.f; + for (size_t ch = 0; ch < num_render_channels; ++ch) { + x[0][ch][0] = 5000.f; + } for (size_t k = 0; k < render_delay_buffer->GetRenderBuffer()->GetFftBuffer().size(); ++k) { render_delay_buffer->Insert(x); diff --git a/modules/audio_processing/aec3/erl_estimator.cc b/modules/audio_processing/aec3/erl_estimator.cc index 4a0c441520..01cc33cb80 100644 --- a/modules/audio_processing/aec3/erl_estimator.cc +++ b/modules/audio_processing/aec3/erl_estimator.cc @@ -39,20 +39,69 @@ void ErlEstimator::Reset() { } void ErlEstimator::Update( - bool converged_filter, - rtc::ArrayView render_spectrum, - rtc::ArrayView capture_spectrum) { - const auto& X2 = render_spectrum; - const auto& Y2 = capture_spectrum; + const std::vector& converged_filters, + rtc::ArrayView> render_spectra, + rtc::ArrayView> + capture_spectra) { + const size_t num_capture_channels = converged_filters.size(); + RTC_DCHECK_EQ(capture_spectra.size(), num_capture_channels); // Corresponds to WGN of power -46 dBFS. constexpr float kX2Min = 44015068.0f; + const auto first_converged_iter = + std::find(converged_filters.begin(), converged_filters.end(), true); + const bool any_filter_converged = + first_converged_iter != converged_filters.end(); + if (++blocks_since_reset_ < startup_phase_length_blocks__ || - !converged_filter) { + !any_filter_converged) { return; } + // Use the maximum spectrum across capture and the maximum across render. + std::array max_capture_spectrum_data; + std::array max_capture_spectrum = + capture_spectra[/*channel=*/0]; + if (num_capture_channels > 1) { + // Initialize using the first channel with a converged filter. + const size_t first_converged = + std::distance(converged_filters.begin(), first_converged_iter); + RTC_DCHECK_GE(first_converged, 0); + RTC_DCHECK_LT(first_converged, num_capture_channels); + max_capture_spectrum_data = capture_spectra[first_converged]; + + for (size_t ch = first_converged + 1; ch < num_capture_channels; ++ch) { + if (!converged_filters[ch]) { + continue; + } + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + max_capture_spectrum_data[k] = + std::max(max_capture_spectrum_data[k], capture_spectra[ch][k]); + } + } + max_capture_spectrum = max_capture_spectrum_data; + } + + const size_t num_render_channels = render_spectra.size(); + std::array max_render_spectrum_data; + rtc::ArrayView max_render_spectrum = + render_spectra[/*channel=*/0]; + if (num_render_channels > 1) { + std::copy(render_spectra[0].begin(), render_spectra[0].end(), + max_render_spectrum_data.begin()); + for (size_t ch = 1; ch < num_render_channels; ++ch) { + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + max_render_spectrum_data[k] = + std::max(max_render_spectrum_data[k], render_spectra[ch][k]); + } + } + max_render_spectrum = max_render_spectrum_data; + } + + const auto& X2 = max_render_spectrum; + const auto& Y2 = max_capture_spectrum; + // Update the estimates in a maximum statistics manner. for (size_t k = 1; k < kFftLengthBy2; ++k) { if (X2[k] > kX2Min) { diff --git a/modules/audio_processing/aec3/erl_estimator.h b/modules/audio_processing/aec3/erl_estimator.h index 25dc39c3e6..89bf6ace36 100644 --- a/modules/audio_processing/aec3/erl_estimator.h +++ b/modules/audio_processing/aec3/erl_estimator.h @@ -14,6 +14,7 @@ #include #include +#include #include "api/array_view.h" #include "modules/audio_processing/aec3/aec3_common.h" @@ -31,9 +32,11 @@ class ErlEstimator { void Reset(); // Updates the ERL estimate. - void Update(bool converged_filter, - rtc::ArrayView render_spectrum, - rtc::ArrayView capture_spectrum); + void Update(const std::vector& converged_filters, + rtc::ArrayView> + render_spectra, + rtc::ArrayView> + capture_spectra); // Returns the most recent ERL estimate. const std::array& Erl() const { return erl_; } diff --git a/modules/audio_processing/aec3/erl_estimator_unittest.cc b/modules/audio_processing/aec3/erl_estimator_unittest.cc index 1b965d0f9a..344551dd1f 100644 --- a/modules/audio_processing/aec3/erl_estimator_unittest.cc +++ b/modules/audio_processing/aec3/erl_estimator_unittest.cc @@ -10,11 +10,19 @@ #include "modules/audio_processing/aec3/erl_estimator.h" +#include "rtc_base/strings/string_builder.h" #include "test/gtest.h" namespace webrtc { namespace { +std::string ProduceDebugText(size_t num_render_channels, + size_t num_capture_channels) { + rtc::StringBuilder ss; + ss << "Render channels: " << num_render_channels; + ss << ", Capture channels: " << num_capture_channels; + return ss.Release(); +} void VerifyErl(const std::array& erl, float erl_time_domain, @@ -28,45 +36,65 @@ void VerifyErl(const std::array& erl, // Verifies that the correct ERL estimates are achieved. TEST(ErlEstimator, Estimates) { - std::array X2; - std::array Y2; + for (size_t num_render_channels : {1, 2, 8}) { + for (size_t num_capture_channels : {1, 2, 8}) { + SCOPED_TRACE(ProduceDebugText(num_render_channels, num_capture_channels)); + std::vector> X2( + num_render_channels); + for (auto& X2_ch : X2) { + X2_ch.fill(0.f); + } + std::vector> Y2( + num_capture_channels); + for (auto& Y2_ch : Y2) { + Y2_ch.fill(0.f); + } + std::vector converged_filters(num_capture_channels, false); + const size_t converged_idx = num_capture_channels - 1; + converged_filters[converged_idx] = true; - ErlEstimator estimator(0); + ErlEstimator estimator(0); - // Verifies that the ERL estimate is properly reduced to lower values. - X2.fill(500 * 1000.f * 1000.f); - Y2.fill(10 * X2[0]); - for (size_t k = 0; k < 200; ++k) { - estimator.Update(true, X2, Y2); + // Verifies that the ERL estimate is properly reduced to lower values. + for (auto& X2_ch : X2) { + X2_ch.fill(500 * 1000.f * 1000.f); + } + Y2[converged_idx].fill(10 * X2[0][0]); + for (size_t k = 0; k < 200; ++k) { + estimator.Update(converged_filters, X2, Y2); + } + VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 10.f); + + // Verifies that the ERL is not immediately increased when the ERL in the + // data increases. + Y2[converged_idx].fill(10000 * X2[0][0]); + for (size_t k = 0; k < 998; ++k) { + estimator.Update(converged_filters, X2, Y2); + } + VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 10.f); + + // Verifies that the rate of increase is 3 dB. + estimator.Update(converged_filters, X2, Y2); + VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 20.f); + + // Verifies that the maximum ERL is achieved when there are no low RLE + // estimates. + for (size_t k = 0; k < 1000; ++k) { + estimator.Update(converged_filters, X2, Y2); + } + VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 1000.f); + + // Verifies that the ERL estimate is is not updated for low-level signals + for (auto& X2_ch : X2) { + X2_ch.fill(1000.f * 1000.f); + } + Y2[converged_idx].fill(10 * X2[0][0]); + for (size_t k = 0; k < 200; ++k) { + estimator.Update(converged_filters, X2, Y2); + } + VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 1000.f); + } } - VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 10.f); - - // Verifies that the ERL is not immediately increased when the ERL in the data - // increases. - Y2.fill(10000 * X2[0]); - for (size_t k = 0; k < 998; ++k) { - estimator.Update(true, X2, Y2); - } - VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 10.f); - - // Verifies that the rate of increase is 3 dB. - estimator.Update(true, X2, Y2); - VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 20.f); - - // Verifies that the maximum ERL is achieved when there are no low RLE - // estimates. - for (size_t k = 0; k < 1000; ++k) { - estimator.Update(true, X2, Y2); - } - VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 1000.f); - - // Verifies that the ERL estimate is is not updated for low-level signals - X2.fill(1000.f * 1000.f); - Y2.fill(10 * X2[0]); - for (size_t k = 0; k < 200; ++k) { - estimator.Update(true, X2, Y2); - } - VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 1000.f); } } // namespace webrtc