From b24ebc535b3602d682efce517bf5e4fbeae36e30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20de=20Vicente=20Pe=C3=B1a?= Date: Mon, 31 Oct 2022 12:25:01 +0100 Subject: [PATCH] pre echo delay: adding different options for detecting pre echoes. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug: webrtc:14205 Change-Id: I9de13c8525914278a2961bd1193b1ce2472c8c02 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/280900 Reviewed-by: Per Åhgren Commit-Queue: Jesus de Vicente Pena Reviewed-by: Lionel Koenig Cr-Commit-Position: refs/heads/main@{#38511} --- modules/audio_processing/aec3/BUILD.gn | 1 + .../audio_processing/aec3/matched_filter.cc | 97 ++++++++++++++++--- .../audio_processing/aec3/matched_filter.h | 16 +++ .../aec3/matched_filter_unittest.cc | 54 +++++++++++ 4 files changed, 155 insertions(+), 13 deletions(-) diff --git a/modules/audio_processing/aec3/BUILD.gn b/modules/audio_processing/aec3/BUILD.gn index 679ce48747..f5eb5d5951 100644 --- a/modules/audio_processing/aec3/BUILD.gn +++ b/modules/audio_processing/aec3/BUILD.gn @@ -228,6 +228,7 @@ rtc_source_set("matched_filter") { deps = [ ":aec3_common", "../../../api:array_view", + "../../../rtc_base:gtest_prod", "../../../rtc_base/system:arch", ] absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] diff --git a/modules/audio_processing/aec3/matched_filter.cc b/modules/audio_processing/aec3/matched_filter.cc index c5e394ad2f..a9054825c6 100644 --- a/modules/audio_processing/aec3/matched_filter.cc +++ b/modules/audio_processing/aec3/matched_filter.cc @@ -29,7 +29,9 @@ #include "modules/audio_processing/aec3/downsampled_render_buffer.h" #include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/checks.h" +#include "rtc_base/experiments/field_trial_parser.h" #include "rtc_base/logging.h" +#include "system_wrappers/include/field_trial.h" namespace { @@ -53,23 +55,89 @@ void UpdateAccumulatedError( } } -size_t 
ComputePreEchoLag(const rtc::ArrayView<const float> accumulated_error, - size_t lag, - size_t alignment_shift_winner) { +size_t ComputePreEchoLag( + const webrtc::MatchedFilter::PreEchoConfiguration& pre_echo_configuration, + const rtc::ArrayView<const float> accumulated_error, + size_t lag, + size_t alignment_shift_winner) { + RTC_DCHECK_GE(lag, alignment_shift_winner); size_t pre_echo_lag_estimate = lag - alignment_shift_winner; size_t maximum_pre_echo_lag = std::min(pre_echo_lag_estimate / kAccumulatedErrorSubSampleRate, accumulated_error.size()); - for (size_t k = 1; k < maximum_pre_echo_lag; ++k) { - if (accumulated_error[k] < 0.5f * accumulated_error[k - 1] && - accumulated_error[k] < 0.5f) { - pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1; + switch (pre_echo_configuration.mode) { + case 0: + // Mode 0: Pre echo lag is defined as the first coefficient with an error + // lower than a threshold with a certain decrease slope. + for (size_t k = 1; k < maximum_pre_echo_lag; ++k) { + if (accumulated_error[k] < + pre_echo_configuration.threshold * accumulated_error[k - 1] && + accumulated_error[k] < pre_echo_configuration.threshold) { + pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1; + break; + } + } + break; + case 1: + // Mode 1: Pre echo lag is defined as the first coefficient with an error + // lower than a certain threshold. + for (size_t k = 0; k < maximum_pre_echo_lag; ++k) { + if (accumulated_error[k] < pre_echo_configuration.threshold) { + pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1; + break; + } + } + break; + case 2: + // Mode 2: Pre echo lag is defined as the closest coefficient to the lag + // with an error lower than a certain threshold. 
+ for (int k = static_cast<int>(maximum_pre_echo_lag) - 1; k >= 0; --k) { + if (accumulated_error[k] > pre_echo_configuration.threshold) { + break; + } + pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1; + } + break; + default: + RTC_DCHECK_NOTREACHED(); break; - } } return pre_echo_lag_estimate + alignment_shift_winner; } +webrtc::MatchedFilter::PreEchoConfiguration FetchPreEchoConfiguration() { + float threshold = 0.5f; + int mode = 0; + const std::string pre_echo_configuration_field_trial = + webrtc::field_trial::FindFullName("WebRTC-Aec3PreEchoConfiguration"); + webrtc::FieldTrialParameter<double> threshold_field_trial_parameter( + /*key=*/"threshold", /*default_value=*/threshold); + webrtc::FieldTrialParameter<int> mode_field_trial_parameter( + /*key=*/"mode", /*default_value=*/mode); + webrtc::ParseFieldTrial( + {&threshold_field_trial_parameter, &mode_field_trial_parameter}, + pre_echo_configuration_field_trial); + float threshold_read = + static_cast<float>(threshold_field_trial_parameter.Get()); + int mode_read = mode_field_trial_parameter.Get(); + if (threshold_read < 1.0f && threshold_read > 0.0f) { + threshold = threshold_read; + } else { + RTC_LOG(LS_ERROR) + << "AEC3: Pre echo configuration: wrong input, threshold = " + << threshold_read << "."; + } + if (mode_read >= 0 && mode_read <= 3) { + mode = mode_read; + } else { + RTC_LOG(LS_ERROR) << "AEC3: Pre echo configuration: wrong input, mode = " + << mode_read << "."; + } + RTC_LOG(LS_INFO) << "AEC3: Pre echo configuration: threshold = " << threshold + << ", mode = " << mode << "."; + return {.threshold = threshold, .mode = mode}; +} + } // namespace namespace webrtc { @@ -612,7 +680,8 @@ MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper, smoothing_fast_(smoothing_fast), smoothing_slow_(smoothing_slow), matching_filter_threshold_(matching_filter_threshold), - detect_pre_echo_(detect_pre_echo) { + detect_pre_echo_(detect_pre_echo), + pre_echo_config_(FetchPreEchoConfiguration()) { 
RTC_DCHECK(data_dumper); RTC_DCHECK_LT(0, window_size_sub_blocks); RTC_DCHECK((kBlockSize % sub_block_size) == 0); @@ -753,7 +822,8 @@ void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer, 1.0f / error_sum_anchor); } reported_lag_estimate_->pre_echo_lag = ComputePreEchoLag( - accumulated_error_[winner_index], winner_lag_.value(), + pre_echo_config_, accumulated_error_[winner_index], + winner_lag_.value(), winner_index * filter_intra_lag_shift_ /*alignment_shift_winner*/); } last_detected_best_lag_filter_ = winner_index; @@ -794,9 +864,10 @@ void MatchedFilter::Dump() { "aec3_correlator_error_" + std::to_string(n) + "_h"; data_dumper_->DumpRaw(dumper_error.c_str(), accumulated_error_[n]); - size_t pre_echo_lag = ComputePreEchoLag( - accumulated_error_[n], lag_estimate + n * filter_intra_lag_shift_, - n * filter_intra_lag_shift_); + size_t pre_echo_lag = + ComputePreEchoLag(pre_echo_config_, accumulated_error_[n], + lag_estimate + n * filter_intra_lag_shift_, + n * filter_intra_lag_shift_); std::string dumper_pre_lag = "aec3_correlator_pre_echo_lag_" + std::to_string(n); data_dumper_->DumpRaw(dumper_pre_lag.c_str(), pre_echo_lag); diff --git a/modules/audio_processing/aec3/matched_filter.h b/modules/audio_processing/aec3/matched_filter.h index 760d5e39fd..1560fb02f1 100644 --- a/modules/audio_processing/aec3/matched_filter.h +++ b/modules/audio_processing/aec3/matched_filter.h @@ -18,6 +18,7 @@ #include "absl/types/optional.h" #include "api/array_view.h" #include "modules/audio_processing/aec3/aec3_common.h" +#include "rtc_base/gtest_prod_util.h" #include "rtc_base/system/arch.h" namespace webrtc { @@ -105,6 +106,11 @@ class MatchedFilter { size_t pre_echo_lag = 0; }; + struct PreEchoConfiguration { + const float threshold; + const int mode; + }; + MatchedFilter(ApmDataDumper* data_dumper, Aec3Optimization optimization, size_t sub_block_size, @@ -147,6 +153,15 @@ class MatchedFilter { size_t downsampling_factor) const; private: + 
FRIEND_TEST_ALL_PREFIXES(MatchedFilterFieldTrialTest, + PreEchoConfigurationTest); + FRIEND_TEST_ALL_PREFIXES(MatchedFilterFieldTrialTest, + WrongPreEchoConfigurationTest); + + // Only for testing. Gets the pre echo detection configuration. + const PreEchoConfiguration& GetPreEchoConfiguration() const { + return pre_echo_config_; + } void Dump(); ApmDataDumper* const data_dumper_; @@ -166,6 +181,7 @@ class MatchedFilter { const float smoothing_slow_; const float matching_filter_threshold_; const bool detect_pre_echo_; + const PreEchoConfiguration pre_echo_config_; }; } // namespace webrtc diff --git a/modules/audio_processing/aec3/matched_filter_unittest.cc b/modules/audio_processing/aec3/matched_filter_unittest.cc index b080308191..0a04c7809c 100644 --- a/modules/audio_processing/aec3/matched_filter_unittest.cc +++ b/modules/audio_processing/aec3/matched_filter_unittest.cc @@ -27,6 +27,7 @@ #include "rtc_base/random.h" #include "rtc_base/strings/string_builder.h" #include "system_wrappers/include/cpu_features_wrapper.h" +#include "test/field_trial.h" #include "test/gtest.h" namespace webrtc { @@ -555,4 +556,57 @@ INSTANTIATE_TEST_SUITE_P(_, #endif } // namespace aec3 + +TEST(MatchedFilterFieldTrialTest, PreEchoConfigurationTest) { + float threshold_in = 0.1f; + int mode_in = 2; + rtc::StringBuilder field_trial_name; + field_trial_name << "WebRTC-Aec3PreEchoConfiguration/threshold:" + << threshold_in << ",mode:" << mode_in << "/"; + webrtc::test::ScopedFieldTrials field_trials(field_trial_name.str()); + ApmDataDumper data_dumper(0); + EchoCanceller3Config config; + MatchedFilter matched_filter( + &data_dumper, DetectOptimization(), + kBlockSize / config.delay.down_sampling_factor, + aec3::kWindowSizeSubBlocks, aec3::kNumMatchedFilters, + aec3::kAlignmentShiftSubBlocks, + config.render_levels.poor_excitation_render_limit, + config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, + config.delay.delay_candidate_detection_threshold, 
+ config.delay.detect_pre_echo); + + auto& pre_echo_config = matched_filter.GetPreEchoConfiguration(); + EXPECT_EQ(pre_echo_config.threshold, threshold_in); + EXPECT_EQ(pre_echo_config.mode, mode_in); +} + +TEST(MatchedFilterFieldTrialTest, WrongPreEchoConfigurationTest) { + constexpr float kDefaultThreshold = 0.5f; + constexpr int kDefaultMode = 0; + float threshold_in = -0.1f; + int mode_in = 5; + rtc::StringBuilder field_trial_name; + field_trial_name << "WebRTC-Aec3PreEchoConfiguration/threshold:" + << threshold_in << ",mode:" << mode_in << "/"; + webrtc::test::ScopedFieldTrials field_trials(field_trial_name.str()); + ApmDataDumper data_dumper(0); + EchoCanceller3Config config; + MatchedFilter matched_filter( + &data_dumper, DetectOptimization(), + kBlockSize / config.delay.down_sampling_factor, + aec3::kWindowSizeSubBlocks, aec3::kNumMatchedFilters, + aec3::kAlignmentShiftSubBlocks, + config.render_levels.poor_excitation_render_limit, + config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, + config.delay.delay_candidate_detection_threshold, + config.delay.detect_pre_echo); + + auto& pre_echo_config = matched_filter.GetPreEchoConfiguration(); + EXPECT_EQ(pre_echo_config.threshold, kDefaultThreshold); + EXPECT_EQ(pre_echo_config.mode, kDefaultMode); +} + } // namespace webrtc