From 219208991b63b868d38d4c78fc3b04fa8370b636 Mon Sep 17 00:00:00 2001 From: peah Date: Wed, 8 Feb 2017 05:08:56 -0800 Subject: [PATCH] Adding full initial version of delay estimation functionality in echo canceller 3 This CL adds code to the all the delay estimation functionality that is available for the first version of echo canceller 3. The code completes the class EchoPathDelayEstimator. Note that this code does not yet include any handling of clock-drift so there will be upcoming versions of this code. Also note that the CL includes some minor changes in other files for echo canceller 3. BUG=webrtc:6018 Review-Url: https://codereview.webrtc.org/2644123002 Cr-Commit-Position: refs/heads/master@{#16489} --- webrtc/modules/audio_processing/BUILD.gn | 9 + .../audio_processing/aec3/aec3_constants.h | 16 ++ .../audio_processing/aec3/block_processor.cc | 4 +- .../aec3/block_processor_unittest.cc | 6 + .../audio_processing/aec3/decimator_by_4.cc | 44 ++++ .../audio_processing/aec3/decimator_by_4.h | 39 ++++ .../aec3/decimator_by_4_unittest.cc | 127 ++++++++++++ .../audio_processing/aec3/echo_canceller3.cc | 2 + .../aec3/echo_canceller3_unittest.cc | 6 + .../aec3/echo_path_delay_estimator.cc | 60 +++++- .../aec3/echo_path_delay_estimator.h | 14 +- .../echo_path_delay_estimator_unittest.cc | 134 ++++++++---- .../audio_processing/aec3/echo_remover.cc | 3 +- .../audio_processing/aec3/matched_filter.cc | 172 ++++++++++++++++ .../audio_processing/aec3/matched_filter.h | 84 ++++++++ .../aec3/matched_filter_lag_aggregator.cc | 66 ++++++ .../aec3/matched_filter_lag_aggregator.h | 46 +++++ .../matched_filter_lag_aggregator_unittest.cc | 192 ++++++++++++++++++ .../aec3/matched_filter_unittest.cc | 190 +++++++++++++++++ .../aec3/render_delay_controller.cc | 5 +- .../aec3/render_delay_controller_unittest.cc | 6 +- 21 files changed, 1171 insertions(+), 54 deletions(-) create mode 100644 webrtc/modules/audio_processing/aec3/decimator_by_4.cc create mode 100644 webrtc/modules/audio_processing/aec3/decimator_by_4.h create mode 100644 webrtc/modules/audio_processing/aec3/decimator_by_4_unittest.cc create mode 100644 webrtc/modules/audio_processing/aec3/matched_filter.cc create mode 100644 webrtc/modules/audio_processing/aec3/matched_filter.h create mode 100644 webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc create mode 100644 webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.h create mode 100644 webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc create mode 100644 webrtc/modules/audio_processing/aec3/matched_filter_unittest.cc diff --git a/webrtc/modules/audio_processing/BUILD.gn b/webrtc/modules/audio_processing/BUILD.gn index dfedbee11c..2684a782c3 100644 --- a/webrtc/modules/audio_processing/BUILD.gn +++ b/webrtc/modules/audio_processing/BUILD.gn @@ -33,6 +33,8 @@ rtc_static_library("audio_processing") { "aec3/block_processor.h", "aec3/cascaded_biquad_filter.cc", "aec3/cascaded_biquad_filter.h", + "aec3/decimator_by_4.cc", + "aec3/decimator_by_4.h", "aec3/echo_canceller3.cc", "aec3/echo_canceller3.h", "aec3/echo_path_delay_estimator.cc", @@ -42,6 +44,10 @@ rtc_static_library("audio_processing") { "aec3/echo_remover.h", "aec3/frame_blocker.cc", "aec3/frame_blocker.h", + "aec3/matched_filter.cc", + "aec3/matched_filter.h", + "aec3/matched_filter_lag_aggregator.cc", + "aec3/matched_filter_lag_aggregator.h", "aec3/render_delay_buffer.cc", "aec3/render_delay_buffer.h", "aec3/render_delay_controller.cc", @@ -519,10 +525,13 @@ if (rtc_include_tests) { "aec3/block_framer_unittest.cc", "aec3/block_processor_unittest.cc", "aec3/cascaded_biquad_filter_unittest.cc", + "aec3/decimator_by_4_unittest.cc", "aec3/echo_canceller3_unittest.cc", "aec3/echo_path_delay_estimator_unittest.cc", "aec3/echo_remover_unittest.cc", "aec3/frame_blocker_unittest.cc", + "aec3/matched_filter_lag_aggregator_unittest.cc", + "aec3/matched_filter_unittest.cc", "aec3/mock/mock_block_processor.h", "aec3/mock/mock_echo_remover.h", "aec3/mock/mock_render_delay_buffer.h", diff --git a/webrtc/modules/audio_processing/aec3/aec3_constants.h b/webrtc/modules/audio_processing/aec3/aec3_constants.h index 946e50c786..054b0d8afd 100644 --- a/webrtc/modules/audio_processing/aec3/aec3_constants.h +++ b/webrtc/modules/audio_processing/aec3/aec3_constants.h @@ -34,6 +34,11 @@ constexpr int LowestBandRate(int sample_rate_hz) { return sample_rate_hz == 8000 ? sample_rate_hz : 16000; } +constexpr bool ValidFullBandRate(int sample_rate_hz) { + return sample_rate_hz == 8000 || sample_rate_hz == 16000 || + sample_rate_hz == 32000 || sample_rate_hz == 48000; +} + static_assert(1 == NumBandsForRate(8000), "Number of bands for 8 kHz"); static_assert(1 == NumBandsForRate(16000), "Number of bands for 16 kHz"); static_assert(2 == NumBandsForRate(32000), "Number of bands for 32 kHz"); @@ -47,6 +52,17 @@ static_assert(16000 == LowestBandRate(32000), static_assert(16000 == LowestBandRate(48000), "Sample rate of band 0 for 48 kHz"); +static_assert(ValidFullBandRate(8000), + "Test that 8 kHz is a valid sample rate"); +static_assert(ValidFullBandRate(16000), + "Test that 16 kHz is a valid sample rate"); +static_assert(ValidFullBandRate(32000), + "Test that 32 kHz is a valid sample rate"); +static_assert(ValidFullBandRate(48000), + "Test that 48 kHz is a valid sample rate"); +static_assert(!ValidFullBandRate(8001), + "Test that 8001 Hz is not a valid sample rate"); + } // namespace webrtc #endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_AEC3_CONSTANTS_H_ diff --git a/webrtc/modules/audio_processing/aec3/block_processor.cc b/webrtc/modules/audio_processing/aec3/block_processor.cc index 1f50a84854..550a21073f 100644 --- a/webrtc/modules/audio_processing/aec3/block_processor.cc +++ b/webrtc/modules/audio_processing/aec3/block_processor.cc @@ -61,7 +61,9 @@ BlockProcessorImpl::BlockProcessorImpl( sample_rate_hz_(sample_rate_hz), render_buffer_(std::move(render_buffer)), delay_controller_(std::move(delay_controller)), - echo_remover_(std::move(echo_remover)) {} + echo_remover_(std::move(echo_remover)) { + RTC_DCHECK(ValidFullBandRate(sample_rate_hz_)); +} BlockProcessorImpl::~BlockProcessorImpl() = default; diff --git a/webrtc/modules/audio_processing/aec3/block_processor_unittest.cc b/webrtc/modules/audio_processing/aec3/block_processor_unittest.cc index b5d6a1432f..ac3af6afb1 100644 --- a/webrtc/modules/audio_processing/aec3/block_processor_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/block_processor_unittest.cc @@ -252,6 +252,12 @@ TEST(BlockProcessor, NullBufferRenderParameter) { ""); } +// Verifies the check for correct sample rate. +TEST(BlockProcessor, WrongSampleRate) { + EXPECT_DEATH(std::unique_ptr(BlockProcessor::Create(8001)), + ""); +} + #endif } // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/decimator_by_4.cc b/webrtc/modules/audio_processing/aec3/decimator_by_4.cc new file mode 100644 index 0000000000..3f4c858ec3 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/decimator_by_4.cc @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "webrtc/modules/audio_processing/aec3/decimator_by_4.h" + +#include "webrtc/base/checks.h" + +namespace webrtc { +namespace { + +// [B,A] = butter(2,1500/16000) which are the same as [B,A] = +// butter(2,750/8000). +const CascadedBiQuadFilter::BiQuadCoefficients kLowPassFilterCoefficients = { + {0.0179f, 0.0357f, 0.0179f}, + {-1.5879f, 0.6594f}}; + +} // namespace + +DecimatorBy4::DecimatorBy4() + : low_pass_filter_(kLowPassFilterCoefficients, 3) {} + +void DecimatorBy4::Decimate(rtc::ArrayView in, + std::array* out) { + RTC_DCHECK_EQ(kBlockSize, in.size()); + RTC_DCHECK(out); + std::array x; + + // Limit the frequency content of the signal to avoid aliasing. + low_pass_filter_.Process(in, x); + + // Downsample the signal. + for (size_t j = 0, k = 0; j < out->size(); ++j, k += 4) { + RTC_DCHECK_GT(kBlockSize, k); + (*out)[j] = x[k]; + } +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/decimator_by_4.h b/webrtc/modules/audio_processing/aec3/decimator_by_4.h new file mode 100644 index 0000000000..076c1688c8 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/decimator_by_4.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_DECIMATOR_BY_4_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_DECIMATOR_BY_4_H_ + +#include + +#include "webrtc/base/array_view.h" +#include "webrtc/base/constructormagic.h" +#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/aec3/cascaded_biquad_filter.h" + +namespace webrtc { + +// Provides functionality for decimating a signal by 4. +class DecimatorBy4 { + public: + DecimatorBy4(); + + // Downsamples the signal. + void Decimate(rtc::ArrayView in, + std::array* out); + + private: + CascadedBiQuadFilter low_pass_filter_; + + RTC_DISALLOW_COPY_AND_ASSIGN(DecimatorBy4); +}; +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_DECIMATOR_BY_4_H_ diff --git a/webrtc/modules/audio_processing/aec3/decimator_by_4_unittest.cc b/webrtc/modules/audio_processing/aec3/decimator_by_4_unittest.cc new file mode 100644 index 0000000000..a7699ba64f --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/decimator_by_4_unittest.cc @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/decimator_by_4.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/test/gtest.h" + +namespace webrtc { + +namespace { + +std::string ProduceDebugText(int sample_rate_hz) { + std::ostringstream ss; + ss << "Sample rate: " << sample_rate_hz; + return ss.str(); +} + +constexpr float kPi = 3.141592f; +constexpr size_t kNumStartupBlocks = 50; +constexpr size_t kNumBlocks = 1000; + +void ProduceDecimatedSinusoidalOutputPower(int sample_rate_hz, + float sinusoidal_frequency_hz, + float* input_power, + float* output_power) { + float input[kBlockSize * kNumBlocks]; + + // Produce a sinusoid of the specified frequency. + for (size_t k = 0; k < kBlockSize * kNumBlocks; ++k) { + input[k] = + 32767.f * sin(2.f * kPi * sinusoidal_frequency_hz * k / sample_rate_hz); + } + + DecimatorBy4 decimator; + std::array output; + + for (size_t k = 0; k < kNumBlocks; ++k) { + std::array sub_block; + + decimator.Decimate( + rtc::ArrayView(&input[k * kBlockSize], kBlockSize), + &sub_block); + + std::copy(sub_block.begin(), sub_block.end(), + output.begin() + k * kSubBlockSize); + } + + ASSERT_GT(kNumBlocks, kNumStartupBlocks); + rtc::ArrayView input_to_evaluate( + &input[kNumStartupBlocks * kBlockSize], + (kNumBlocks - kNumStartupBlocks) * kBlockSize); + rtc::ArrayView output_to_evaluate( + &output[kNumStartupBlocks * kSubBlockSize], + (kNumBlocks - kNumStartupBlocks) * kSubBlockSize); + *input_power = + std::inner_product(input_to_evaluate.begin(), input_to_evaluate.end(), + input_to_evaluate.begin(), 0.f) / + input_to_evaluate.size(); + *output_power = + std::inner_product(output_to_evaluate.begin(), output_to_evaluate.end(), + output_to_evaluate.begin(), 0.f) / + output_to_evaluate.size(); +} + +} // namespace + +// Verifies that there is little aliasing from upper frequencies in the +// downsampling. +TEST(DecimatorBy4, NoLeakageFromUpperFrequencies) { + float input_power; + float output_power; + for (auto rate : {8000, 16000, 32000, 48000}) { + ProduceDebugText(rate); + ProduceDecimatedSinusoidalOutputPower(rate, 3.f / 8.f * rate, &input_power, + &output_power); + EXPECT_GT(0.0001f * input_power, output_power); + } +} + +// Verifies that the impact of low-frequency content is small during the +// downsampling. +TEST(DecimatorBy4, NoImpactOnLowerFrequencies) { + float input_power; + float output_power; + for (auto rate : {8000, 16000, 32000, 48000}) { + ProduceDebugText(rate); + ProduceDecimatedSinusoidalOutputPower(rate, 200.f, &input_power, + &output_power); + EXPECT_LT(0.7f * input_power, output_power); + } +} + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) +// Verifies the check for the input size. +TEST(DecimatorBy4, WrongInputSize) { + DecimatorBy4 decimator; + std::vector x(std::vector(kBlockSize - 1, 0.f)); + std::array x_downsampled; + EXPECT_DEATH(decimator.Decimate(x, &x_downsampled), ""); +} + +// Verifies the check for non-null output parameter. +TEST(DecimatorBy4, NullOutput) { + DecimatorBy4 decimator; + std::vector x(std::vector(kBlockSize, 0.f)); + EXPECT_DEATH(decimator.Decimate(x, nullptr), ""); +} + +#endif + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/echo_canceller3.cc b/webrtc/modules/audio_processing/aec3/echo_canceller3.cc index ec7b55a215..a4b796d969 100644 --- a/webrtc/modules/audio_processing/aec3/echo_canceller3.cc +++ b/webrtc/modules/audio_processing/aec3/echo_canceller3.cc @@ -227,6 +227,8 @@ EchoCanceller3::EchoCanceller3(int sample_rate_hz, std::vector(frame_length_, 0.f)), block_(num_bands_, std::vector(kBlockSize, 0.f)), sub_frame_view_(num_bands_) { + RTC_DCHECK(ValidFullBandRate(sample_rate_hz_)); + std::unique_ptr render_highpass_filter; if (use_highpass_filter) { render_highpass_filter.reset(new CascadedBiQuadFilter( diff --git a/webrtc/modules/audio_processing/aec3/echo_canceller3_unittest.cc b/webrtc/modules/audio_processing/aec3/echo_canceller3_unittest.cc index afe429d1d2..8ccaa51e52 100644 --- a/webrtc/modules/audio_processing/aec3/echo_canceller3_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/echo_canceller3_unittest.cc @@ -719,6 +719,12 @@ TEST(EchoCanceller3InputCheck, NullCaptureProcessingParameter) { EXPECT_DEATH(EchoCanceller3(8000, false).ProcessCapture(nullptr, false), ""); } +// Verifies the check for correct sample rate. +TEST(EchoCanceller3InputCheck, WrongSampleRate) { + ApmDataDumper data_dumper(0); + EXPECT_DEATH(EchoCanceller3(8001, false), ""); +} + #endif } // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.cc b/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.cc index 3ae8f879ef..539832df03 100644 --- a/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.cc +++ b/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.cc @@ -9,29 +9,71 @@ */ #include "webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.h" +#include +#include + #include "webrtc/base/checks.h" #include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/include/audio_processing.h" #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" namespace webrtc { -// TODO(peah): Add functionality. -EchoPathDelayEstimator::EchoPathDelayEstimator(ApmDataDumper* data_dumper, - int sample_rate_hz) { +namespace { + +constexpr size_t kNumMatchedFilters = 4; +constexpr size_t kMatchedFilterWindowSizeSubBlocks = 32; +constexpr size_t kMatchedFilterAlignmentShiftSizeSubBlocks = + kMatchedFilterWindowSizeSubBlocks * 3 / 4; + +constexpr int kDownSamplingFactor = 4; +} // namespace + +EchoPathDelayEstimator::EchoPathDelayEstimator(ApmDataDumper* data_dumper) + : data_dumper_(data_dumper), + matched_filter_(data_dumper_, + kMatchedFilterWindowSizeSubBlocks, + kNumMatchedFilters, + kMatchedFilterAlignmentShiftSizeSubBlocks), + matched_filter_lag_aggregator_(data_dumper_, + matched_filter_.NumLagEstimates()) { RTC_DCHECK(data_dumper); - RTC_DCHECK(sample_rate_hz == 8000 || sample_rate_hz == 16000 || - sample_rate_hz == 32000 || sample_rate_hz == 48000); } EchoPathDelayEstimator::~EchoPathDelayEstimator() = default; -// TODO(peah): Add functionality. rtc::Optional EchoPathDelayEstimator::EstimateDelay( rtc::ArrayView render, rtc::ArrayView capture) { - RTC_DCHECK_EQ(render.size(), kBlockSize); - RTC_DCHECK_EQ(capture.size(), kBlockSize); - return rtc::Optional(); + RTC_DCHECK_EQ(kBlockSize, capture.size()); + RTC_DCHECK_EQ(render.size(), capture.size()); + + std::array downsampled_render; + std::array downsampled_capture; + + render_decimator_.Decimate(render, &downsampled_render); + capture_decimator_.Decimate(capture, &downsampled_capture); + + matched_filter_.Update(downsampled_render, downsampled_capture); + + rtc::Optional aggregated_matched_filter_lag = + matched_filter_lag_aggregator_.Aggregate( + matched_filter_.GetLagEstimates()); + + // TODO(peah): Move this logging outside of this class once EchoCanceller3 + // development is done. + data_dumper_->DumpRaw("aec3_echo_path_delay_estimator_delay", + aggregated_matched_filter_lag + ? static_cast(*aggregated_matched_filter_lag * + kDownSamplingFactor) + : -1); + + // Return the detected delay in samples as the aggregated matched filter lag + // compensated by the down sampling factor for the signal being correlated. + return aggregated_matched_filter_lag + ? rtc::Optional(*aggregated_matched_filter_lag * + kDownSamplingFactor) + : rtc::Optional(); } } // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.h b/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.h index ae8ab4318e..bbe9a7c32a 100644 --- a/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.h +++ b/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.h @@ -15,19 +15,31 @@ #include "webrtc/base/constructormagic.h" #include "webrtc/base/optional.h" +#include "webrtc/modules/audio_processing/aec3/matched_filter.h" +#include "webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.h" +#include "webrtc/modules/audio_processing/aec3/decimator_by_4.h" namespace webrtc { class ApmDataDumper; +// Estimates the delay of the echo path. class EchoPathDelayEstimator { public: - EchoPathDelayEstimator(ApmDataDumper* data_dumper, int sample_rate_hz); + explicit EchoPathDelayEstimator(ApmDataDumper* data_dumper); ~EchoPathDelayEstimator(); + + // Produce a delay estimate if such is avaliable. rtc::Optional EstimateDelay(rtc::ArrayView render, rtc::ArrayView capture); private: + ApmDataDumper* const data_dumper_; + DecimatorBy4 render_decimator_; + DecimatorBy4 capture_decimator_; + MatchedFilter matched_filter_; + MatchedFilterLagAggregator matched_filter_lag_aggregator_; + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(EchoPathDelayEstimator); }; } // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc b/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc index c3146ed7b8..ba9ff23540 100644 --- a/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc @@ -10,19 +10,22 @@ #include "webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.h" +#include #include #include +#include "webrtc/base/random.h" #include "webrtc/modules/audio_processing/aec3/aec3_constants.h" #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" +#include "webrtc/modules/audio_processing/test/echo_canceller_test_tools.h" #include "webrtc/test/gtest.h" namespace webrtc { namespace { -std::string ProduceDebugText(int sample_rate_hz) { +std::string ProduceDebugText(size_t delay) { std::ostringstream ss; - ss << "Sample rate: " << sample_rate_hz; + ss << "Delay: " << delay; return ss.str(); } @@ -30,57 +33,120 @@ std::string ProduceDebugText(int sample_rate_hz) { // Verifies that the basic API calls work. TEST(EchoPathDelayEstimator, BasicApiCalls) { - for (auto rate : {8000, 16000, 32000, 48000}) { - ProduceDebugText(rate); - ApmDataDumper data_dumper(0); - EchoPathDelayEstimator estimator(&data_dumper, rate); - std::vector render(kBlockSize, 0.f); - std::vector capture(kBlockSize, 0.f); - for (size_t k = 0; k < 100; ++k) { - estimator.EstimateDelay(render, capture); + ApmDataDumper data_dumper(0); + EchoPathDelayEstimator estimator(&data_dumper); + std::vector render(kBlockSize, 0.f); + std::vector capture(kBlockSize, 0.f); + for (size_t k = 0; k < 100; ++k) { + estimator.EstimateDelay(render, capture); + } +} + +// Verifies that the delay estimator produces correct delay for artificially +// delayed signals. +TEST(EchoPathDelayEstimator, DelayEstimation) { + Random random_generator(42U); + std::vector render(kBlockSize, 0.f); + std::vector capture(kBlockSize, 0.f); + ApmDataDumper data_dumper(0); + for (size_t delay_samples : {0, 64, 150, 200, 800, 4000}) { + SCOPED_TRACE(ProduceDebugText(delay_samples)); + DelayBuffer signal_delay_buffer(delay_samples); + EchoPathDelayEstimator estimator(&data_dumper); + + rtc::Optional estimated_delay_samples; + for (size_t k = 0; k < (100 + delay_samples / kBlockSize); ++k) { + RandomizeSampleVector(&random_generator, render); + signal_delay_buffer.Delay(render, capture); + estimated_delay_samples = estimator.EstimateDelay(render, capture); } + if (estimated_delay_samples) { + // Due to the internal down-sampling by 4 done inside the delay estimator + // the estimated delay cannot be expected to be closer than 4 samples to + // the true delay. + EXPECT_NEAR(delay_samples, *estimated_delay_samples, 4); + } else { + ADD_FAILURE(); + } + } +} + +// Verifies that the delay estimator does not produce delay estimates too +// quickly. +TEST(EchoPathDelayEstimator, NoInitialDelayestimates) { + Random random_generator(42U); + std::vector render(kBlockSize, 0.f); + std::vector capture(kBlockSize, 0.f); + ApmDataDumper data_dumper(0); + + EchoPathDelayEstimator estimator(&data_dumper); + for (size_t k = 0; k < 19; ++k) { + RandomizeSampleVector(&random_generator, render); + std::copy(render.begin(), render.end(), capture.begin()); + EXPECT_FALSE(estimator.EstimateDelay(render, capture)); + } +} + +// Verifies that the delay estimator does not produce delay estimates for render +// signals of low level. +TEST(EchoPathDelayEstimator, NoDelayEstimatesForLowLevelRenderSignals) { + Random random_generator(42U); + std::vector render(kBlockSize, 0.f); + std::vector capture(kBlockSize, 0.f); + ApmDataDumper data_dumper(0); + EchoPathDelayEstimator estimator(&data_dumper); + for (size_t k = 0; k < 100; ++k) { + RandomizeSampleVector(&random_generator, render); + for (auto& render_k : render) { + render_k *= 100.f / 32767.f; + } + std::copy(render.begin(), render.end(), capture.begin()); + EXPECT_FALSE(estimator.EstimateDelay(render, capture)); + } +} + +// Verifies that the delay estimator does not produce delay estimates for +// uncorrelated signals. +TEST(EchoPathDelayEstimator, NoDelayEstimatesForUncorrelatedSignals) { + Random random_generator(42U); + std::vector render(kBlockSize, 0.f); + std::vector capture(kBlockSize, 0.f); + ApmDataDumper data_dumper(0); + EchoPathDelayEstimator estimator(&data_dumper); + for (size_t k = 0; k < 100; ++k) { + RandomizeSampleVector(&random_generator, render); + RandomizeSampleVector(&random_generator, capture); + EXPECT_FALSE(estimator.EstimateDelay(render, capture)); } } #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) -// Verifies the check for correct sample rate. -TEST(EchoPathDelayEstimator, WrongSampleRate) { - ApmDataDumper data_dumper(0); - EXPECT_DEATH(EchoPathDelayEstimator remover(&data_dumper, 8001), ""); -} - // Verifies the check for the render blocksize. // TODO(peah): Re-enable the test once the issue with memory leaks during DEATH // tests on test bots has been fixed. TEST(EchoPathDelayEstimator, DISABLED_WrongRenderBlockSize) { - for (auto rate : {8000, 16000, 32000, 48000}) { - ProduceDebugText(rate); - ApmDataDumper data_dumper(0); - EchoPathDelayEstimator estimator(&data_dumper, rate); - std::vector render(kBlockSize - 1, 0.f); - std::vector capture(kBlockSize, 0.f); - EXPECT_DEATH(estimator.EstimateDelay(render, capture), ""); - } + ApmDataDumper data_dumper(0); + EchoPathDelayEstimator estimator(&data_dumper); + std::vector render(std::vector(kBlockSize - 1, 0.f)); + std::vector capture(std::vector(kBlockSize, 0.f)); + EXPECT_DEATH(estimator.EstimateDelay(render, capture), ""); } // Verifies the check for the capture blocksize. // TODO(peah): Re-enable the test once the issue with memory leaks during DEATH // tests on test bots has been fixed. -TEST(EchoPathDelayEstimator, DISABLED_WrongCaptureBlockSize) { - for (auto rate : {8000, 16000, 32000, 48000}) { - ProduceDebugText(rate); - ApmDataDumper data_dumper(0); - EchoPathDelayEstimator estimator(&data_dumper, rate); - std::vector render(kBlockSize, 0.f); - std::vector capture(kBlockSize - 1, 0.f); - EXPECT_DEATH(estimator.EstimateDelay(render, capture), ""); - } +TEST(EchoPathDelayEstimator, WrongCaptureBlockSize) { + ApmDataDumper data_dumper(0); + EchoPathDelayEstimator estimator(&data_dumper); + std::vector render(std::vector(kBlockSize, 0.f)); + std::vector capture(std::vector(kBlockSize - 1, 0.f)); + EXPECT_DEATH(estimator.EstimateDelay(render, capture), ""); } // Verifies the check for non-null data dumper. TEST(EchoPathDelayEstimator, NullDataDumper) { - EXPECT_DEATH(EchoPathDelayEstimator(nullptr, 8000), ""); + EXPECT_DEATH(EchoPathDelayEstimator(nullptr), ""); } #endif diff --git a/webrtc/modules/audio_processing/aec3/echo_remover.cc b/webrtc/modules/audio_processing/aec3/echo_remover.cc index 2ae5525b4a..ab0b68bb16 100644 --- a/webrtc/modules/audio_processing/aec3/echo_remover.cc +++ b/webrtc/modules/audio_processing/aec3/echo_remover.cc @@ -42,8 +42,7 @@ class EchoRemoverImpl final : public EchoRemover { // TODO(peah): Add functionality. EchoRemoverImpl::EchoRemoverImpl(int sample_rate_hz) : sample_rate_hz_(sample_rate_hz) { - RTC_DCHECK(sample_rate_hz == 8000 || sample_rate_hz == 16000 || - sample_rate_hz == 32000 || sample_rate_hz == 48000); + RTC_DCHECK(ValidFullBandRate(sample_rate_hz_)); } EchoRemoverImpl::~EchoRemoverImpl() = default; diff --git a/webrtc/modules/audio_processing/aec3/matched_filter.cc b/webrtc/modules/audio_processing/aec3/matched_filter.cc new file mode 100644 index 0000000000..f187159911 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/matched_filter.cc @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "webrtc/modules/audio_processing/aec3/matched_filter.h" + +#include +#include + +#include "webrtc/modules/audio_processing/include/audio_processing.h" +#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" + +namespace webrtc { + +MatchedFilter::IndexedBuffer::IndexedBuffer(size_t size) : data(size, 0.f) { + RTC_DCHECK_EQ(0, size % kSubBlockSize); +} + +MatchedFilter::IndexedBuffer::~IndexedBuffer() = default; + +MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper, + size_t window_size_sub_blocks, + int num_matched_filters, + size_t alignment_shift_sub_blocks) + : data_dumper_(data_dumper), + filter_intra_lag_shift_(alignment_shift_sub_blocks * kSubBlockSize), + filters_(num_matched_filters, + std::vector(window_size_sub_blocks * kSubBlockSize, 0.f)), + lag_estimates_(num_matched_filters), + x_buffer_(kSubBlockSize * + (alignment_shift_sub_blocks * num_matched_filters + + window_size_sub_blocks + + 1)) { + RTC_DCHECK(data_dumper); + RTC_DCHECK_EQ(0, x_buffer_.data.size() % kSubBlockSize); + RTC_DCHECK_LT(0, window_size_sub_blocks); +} + +MatchedFilter::~MatchedFilter() = default; + +void MatchedFilter::Update(const std::array& render, + const std::array& capture) { + const std::array& x = render; + const std::array& y = capture; + + const float x2_sum_threshold = filters_[0].size() * 150.f * 150.f; + + // Insert the new subblock into x_buffer. + x_buffer_.index = (x_buffer_.index - kSubBlockSize + x_buffer_.data.size()) % + x_buffer_.data.size(); + RTC_DCHECK_LE(kSubBlockSize, x_buffer_.data.size() - x_buffer_.index); + std::copy(x.rbegin(), x.rend(), x_buffer_.data.begin() + x_buffer_.index); + + // Apply all matched filters. + size_t alignment_shift = 0; + for (size_t n = 0; n < filters_.size(); ++n) { + float error_sum = 0.f; + bool filters_updated = false; + size_t x_start_index = + (x_buffer_.index + alignment_shift + kSubBlockSize - 1) % + x_buffer_.data.size(); + + // Process for all samples in the sub-block. + for (size_t i = 0; i < kSubBlockSize; ++i) { + // As x_buffer is a circular buffer, all of the processing is split into + // two loops around the wrapping of the buffer. + const size_t loop_size_1 = + std::min(filters_[n].size(), x_buffer_.data.size() - x_start_index); + const size_t loop_size_2 = filters_[n].size() - loop_size_1; + RTC_DCHECK_EQ(filters_[n].size(), loop_size_1 + loop_size_2); + + // x * x. + float x2_sum = std::inner_product( + x_buffer_.data.begin() + x_start_index, + x_buffer_.data.begin() + x_start_index + loop_size_1, + x_buffer_.data.begin() + x_start_index, 0.f); + // Apply the matched filter as filter * x. + float s = std::inner_product(filters_[n].begin(), + filters_[n].begin() + loop_size_1, + x_buffer_.data.begin() + x_start_index, 0.f); + + if (loop_size_2 > 0) { + // Update the cumulative sum of x * x. + x2_sum = std::inner_product(x_buffer_.data.begin(), + x_buffer_.data.begin() + loop_size_2, + x_buffer_.data.begin(), x2_sum); + + // Compute the matched filter output filter * x in a cumulative manner. + s = std::inner_product(x_buffer_.data.begin(), + x_buffer_.data.begin() + loop_size_2, + filters_[n].begin() + loop_size_1, s); + } + + // Compute the matched filter error. + const float e = std::min(32767.f, std::max(-32768.f, y[i] - s)); + error_sum += e * e; + + // Update the matched filter estimate in an NLMS manner. + if (x2_sum > x2_sum_threshold) { + filters_updated = true; + RTC_DCHECK_LT(0.f, x2_sum); + const float alpha = 0.7f * e / x2_sum; + + // filter = filter + 0.7 * (y - filter * x) / x * x. + std::transform(filters_[n].begin(), filters_[n].begin() + loop_size_1, + x_buffer_.data.begin() + x_start_index, + filters_[n].begin(), + [&](float a, float b) { return a + alpha * b; }); + + if (loop_size_2 > 0) { + // filter = filter + 0.7 * (y - filter * x) / x * x. + std::transform(x_buffer_.data.begin(), + x_buffer_.data.begin() + loop_size_2, + filters_[n].begin() + loop_size_1, + filters_[n].begin() + loop_size_1, + [&](float a, float b) { return b + alpha * a; }); + } + } + + x_start_index = + x_start_index > 0 ? x_start_index - 1 : x_buffer_.data.size() - 1; + } + + // Compute anchor for the matched filter error. + const float error_sum_anchor = + std::inner_product(y.begin(), y.end(), y.begin(), 0.f); + + // Estimate the lag in the matched filter as the distance to the portion in + // the filter that contributes the most to the matched filter output. This + // is detected as the peak of the matched filter. + const size_t lag_estimate = std::distance( + filters_[n].begin(), + std::max_element( + filters_[n].begin(), filters_[n].end(), + [](float a, float b) -> bool { return a * a < b * b; })); + + // Update the lag estimates for the matched filter. + const float kMatchingFilterThreshold = 0.3f; + lag_estimates_[n] = + LagEstimate(error_sum_anchor - error_sum, + error_sum < kMatchingFilterThreshold * error_sum_anchor, + lag_estimate + alignment_shift, filters_updated); + + // TODO(peah): Remove once development of EchoCanceller3 is fully done. + RTC_DCHECK_EQ(4, filters_.size()); + switch (n) { + case 0: + data_dumper_->DumpRaw("aec3_correlator_0_h", filters_[0]); + break; + case 1: + data_dumper_->DumpRaw("aec3_correlator_1_h", filters_[1]); + break; + case 2: + data_dumper_->DumpRaw("aec3_correlator_2_h", filters_[2]); + break; + case 3: + data_dumper_->DumpRaw("aec3_correlator_3_h", filters_[3]); + break; + default: + RTC_DCHECK(false); + } + + alignment_shift += filter_intra_lag_shift_; + } +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/matched_filter.h b/webrtc/modules/audio_processing/aec3/matched_filter.h new file mode 100644 index 0000000000..3e09d4b971 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/matched_filter.h @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_ + +#include +#include +#include + +#include "webrtc/base/constructormagic.h" +#include "webrtc/base/optional.h" +#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" + +namespace webrtc { + +class ApmDataDumper; + +// Produces recursively updated cross-correlation estimates for several signal +// shifts where the intra-shift spacing is uniform. +class MatchedFilter { + public: + // Stores properties for the lag estimate corresponding to a particular signal + // shift. + struct LagEstimate { + LagEstimate() = default; + LagEstimate(float accuracy, bool reliable, size_t lag, bool updated) + : accuracy(accuracy), reliable(reliable), lag(lag), updated(updated) {} + + float accuracy = 0.f; + bool reliable = false; + size_t lag = 0; + bool updated = false; + }; + + MatchedFilter(ApmDataDumper* data_dumper, + size_t window_size_sub_blocks, + int num_matched_filters, + size_t alignment_shift_sub_blocks); + + ~MatchedFilter(); + + // Updates the correlation with the values in render and capture. + void Update(const std::array& render, + const std::array& capture); + + // Returns the current lag estimates. + rtc::ArrayView GetLagEstimates() const { + return lag_estimates_; + } + + // Returns the number of lag estimates produced using the shifted signals. + size_t NumLagEstimates() const { return filters_.size(); } + + private: + // Provides buffer with a related index. + struct IndexedBuffer { + explicit IndexedBuffer(size_t size); + ~IndexedBuffer(); + + std::vector data; + int index = 0; + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(IndexedBuffer); + }; + + ApmDataDumper* const data_dumper_; + const size_t filter_intra_lag_shift_; + std::vector> filters_; + std::vector lag_estimates_; + IndexedBuffer x_buffer_; + + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(MatchedFilter); +}; + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_ diff --git a/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc b/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc new file mode 100644 index 0000000000..d9176efa94 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.h" + +#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" + +namespace webrtc { + +MatchedFilterLagAggregator::MatchedFilterLagAggregator( + ApmDataDumper* data_dumper, + size_t num_lag_estimates) + : data_dumper_(data_dumper), lag_updates_in_a_row_(num_lag_estimates, 0) { + RTC_DCHECK(data_dumper); + RTC_DCHECK_LT(0, num_lag_estimates); +} + +MatchedFilterLagAggregator::~MatchedFilterLagAggregator() = default; + +rtc::Optional MatchedFilterLagAggregator::Aggregate( + rtc::ArrayView lag_estimates) { + RTC_DCHECK_EQ(lag_updates_in_a_row_.size(), lag_estimates.size()); + + // Count the number of lag updates in a row to ensure that only stable lags + // are taken into account. + for (size_t k = 0; k < lag_estimates.size(); ++k) { + lag_updates_in_a_row_[k] = + lag_estimates[k].updated ? lag_updates_in_a_row_[k] + 1 : 0; + } + + // If available, choose the strongest lag estimate as the best one. + int best_lag_estimate_index = -1; + for (size_t k = 0; k < lag_estimates.size(); ++k) { + if (lag_updates_in_a_row_[k] > 10 && lag_estimates[k].reliable && + (best_lag_estimate_index == -1 || + lag_estimates[k].accuracy > + lag_estimates[best_lag_estimate_index].accuracy)) { + best_lag_estimate_index = k; + } + } + + // TODO(peah): Remove this logging once all development is done. + data_dumper_->DumpRaw("aec3_echo_path_delay_estimator_best_index", + best_lag_estimate_index); + + // Require the same lag to be detected 10 times in a row before considering + // it reliable. + if (best_lag_estimate_index >= 0) { + candidate_counter_ = + (candidate_ == lag_estimates[best_lag_estimate_index].lag) + ? candidate_counter_ + 1 + : 0; + candidate_ = lag_estimates[best_lag_estimate_index].lag; + } + + return candidate_counter_ >= 10 ? rtc::Optional(candidate_) + : rtc::Optional(); +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.h b/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.h new file mode 100644 index 0000000000..ce8a3d67a0 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_LAG_AGGREGATOR_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_LAG_AGGREGATOR_H_ + +#include + +#include "webrtc/base/constructormagic.h" +#include "webrtc/base/optional.h" +#include "webrtc/modules/audio_processing/aec3/matched_filter.h" + +namespace webrtc { + +class ApmDataDumper; + +// Aggregates lag estimates produced by the MatchedFilter class into a single +// reliable combined lag estimate. +class MatchedFilterLagAggregator { + public: + MatchedFilterLagAggregator(ApmDataDumper* data_dumper, + size_t num_lag_estimates); + ~MatchedFilterLagAggregator(); + + // Aggregates the provided lag estimates. + rtc::Optional Aggregate( + rtc::ArrayView lag_estimates); + + private: + ApmDataDumper* const data_dumper_; + std::vector lag_updates_in_a_row_; + size_t candidate_ = 0; + size_t candidate_counter_ = 0; + + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(MatchedFilterLagAggregator); +}; +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_LAG_AGGREGATOR_H_ diff --git a/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc b/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc new file mode 100644 index 0000000000..b76116ba0b --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.h" + +#include +#include +#include + +#include "webrtc/base/array_view.h" +#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" +#include "webrtc/test/gtest.h" + +namespace webrtc { +namespace { + +void VerifyNoAggregateOutputForRepeatedLagAggregation( + size_t num_repetitions, + rtc::ArrayView lag_estimates, + MatchedFilterLagAggregator* aggregator) { + for (size_t k = 0; k < num_repetitions; ++k) { + EXPECT_FALSE(aggregator->Aggregate(lag_estimates)); + } +} + +constexpr size_t kThresholdForRequiredLagUpdatesInARow = 10; +constexpr size_t kThresholdForRequiredIdenticalLagAggregates = 10; + +} // namespace + +// Verifies that the most accurate lag estimate is chosen. +TEST(MatchedFilterLagAggregator, MostAccurateLagChosen) { + constexpr size_t kArtificialLag1 = 5; + constexpr size_t kArtificialLag2 = 10; + ApmDataDumper data_dumper(0); + std::vector lag_estimates(2); + MatchedFilterLagAggregator aggregator(&data_dumper, lag_estimates.size()); + lag_estimates[0] = + MatchedFilter::LagEstimate(1.f, true, kArtificialLag1, true); + lag_estimates[1] = + MatchedFilter::LagEstimate(0.5f, true, kArtificialLag2, true); + + VerifyNoAggregateOutputForRepeatedLagAggregation( + kThresholdForRequiredLagUpdatesInARow + + kThresholdForRequiredIdenticalLagAggregates, + lag_estimates, &aggregator); + rtc::Optional aggregated_lag = aggregator.Aggregate(lag_estimates); + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kArtificialLag1, *aggregated_lag); + + lag_estimates[0] = + MatchedFilter::LagEstimate(0.5f, true, kArtificialLag1, true); + lag_estimates[1] = + MatchedFilter::LagEstimate(1.f, true, kArtificialLag2, true); + + VerifyNoAggregateOutputForRepeatedLagAggregation( + kThresholdForRequiredIdenticalLagAggregates, lag_estimates, &aggregator); + aggregated_lag = aggregator.Aggregate(lag_estimates); + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kArtificialLag2, *aggregated_lag); +} + +// Verifies that varying lag estimates causes lag estimates to not be deemed +// reliable. +TEST(MatchedFilterLagAggregator, + LagEstimateInvarianceRequiredForAggregatedLag) { + constexpr size_t kArtificialLag1 = 5; + constexpr size_t kArtificialLag2 = 10; + ApmDataDumper data_dumper(0); + std::vector lag_estimates(1); + MatchedFilterLagAggregator aggregator(&data_dumper, lag_estimates.size()); + lag_estimates[0] = + MatchedFilter::LagEstimate(1.f, true, kArtificialLag1, true); + VerifyNoAggregateOutputForRepeatedLagAggregation( + kThresholdForRequiredLagUpdatesInARow + + kThresholdForRequiredIdenticalLagAggregates, + lag_estimates, &aggregator); + rtc::Optional aggregated_lag = aggregator.Aggregate(lag_estimates); + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kArtificialLag1, *aggregated_lag); + + lag_estimates[0] = + MatchedFilter::LagEstimate(1.f, true, kArtificialLag2, true); + + VerifyNoAggregateOutputForRepeatedLagAggregation( + kThresholdForRequiredIdenticalLagAggregates, lag_estimates, &aggregator); + aggregated_lag = aggregator.Aggregate(lag_estimates); + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kArtificialLag2, *aggregated_lag); +} + +// Verifies that lag estimate updates are required to produce an updated lag +// aggregate. +TEST(MatchedFilterLagAggregator, LagEstimateUpdatesRequiredForAggregatedLag) { + constexpr size_t kArtificialLag1 = 5; + constexpr size_t kArtificialLag2 = 10; + ApmDataDumper data_dumper(0); + std::vector lag_estimates(1); + MatchedFilterLagAggregator aggregator(&data_dumper, lag_estimates.size()); + lag_estimates[0] = + MatchedFilter::LagEstimate(1.f, true, kArtificialLag1, true); + VerifyNoAggregateOutputForRepeatedLagAggregation( + kThresholdForRequiredLagUpdatesInARow + + kThresholdForRequiredIdenticalLagAggregates, + lag_estimates, &aggregator); + rtc::Optional aggregated_lag = aggregator.Aggregate(lag_estimates); + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kArtificialLag1, *aggregated_lag); + + lag_estimates[0] = + MatchedFilter::LagEstimate(1.f, true, kArtificialLag2, false); + + for (size_t k = 0; k < kThresholdForRequiredLagUpdatesInARow + + kThresholdForRequiredIdenticalLagAggregates + 1; + ++k) { + aggregated_lag = aggregator.Aggregate(lag_estimates); + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kArtificialLag1, *aggregated_lag); + } + + lag_estimates[0] = + MatchedFilter::LagEstimate(1.f, true, kArtificialLag2, true); + for (size_t k = 0; k < kThresholdForRequiredLagUpdatesInARow; ++k) { + aggregated_lag = aggregator.Aggregate(lag_estimates); + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kArtificialLag1, *aggregated_lag); + } + + VerifyNoAggregateOutputForRepeatedLagAggregation( + kThresholdForRequiredIdenticalLagAggregates, lag_estimates, &aggregator); + + aggregated_lag = aggregator.Aggregate(lag_estimates); + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kArtificialLag2, *aggregated_lag); +} + +// Verifies that an aggregated lag is persistent if the lag estimates do not +// change and that an aggregated lag is not produced without gaining lag +// estimate confidence. +TEST(MatchedFilterLagAggregator, PersistentAggregatedLag) { + constexpr size_t kArtificialLag = 5; + ApmDataDumper data_dumper(0); + std::vector lag_estimates(1); + MatchedFilterLagAggregator aggregator(&data_dumper, lag_estimates.size()); + lag_estimates[0] = + MatchedFilter::LagEstimate(1.f, true, kArtificialLag, true); + VerifyNoAggregateOutputForRepeatedLagAggregation( + kThresholdForRequiredLagUpdatesInARow + + kThresholdForRequiredIdenticalLagAggregates, + lag_estimates, &aggregator); + rtc::Optional aggregated_lag = aggregator.Aggregate(lag_estimates); + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kArtificialLag, *aggregated_lag); + + aggregated_lag = aggregator.Aggregate(lag_estimates); + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kArtificialLag, *aggregated_lag); +} + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies the check for correct number of lag estimates. +TEST(MatchedFilterLagAggregator, IncorrectNumberOfLagEstimates) { + ApmDataDumper data_dumper(0); + MatchedFilterLagAggregator aggregator(&data_dumper, 1); + std::vector lag_estimates(2); + + EXPECT_DEATH(aggregator.Aggregate(lag_estimates), ""); +} + +// Verifies the check for non-zero number of lag estimates. +TEST(MatchedFilterLagAggregator, NonZeroLagEstimates) { + ApmDataDumper data_dumper(0); + EXPECT_DEATH(MatchedFilterLagAggregator(&data_dumper, 0), ""); +} + +// Verifies the check for non-null data dumper. +TEST(MatchedFilterLagAggregator, NullDataDumper) { + EXPECT_DEATH(MatchedFilterLagAggregator(nullptr, 1), ""); +} + +#endif + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/matched_filter_unittest.cc b/webrtc/modules/audio_processing/aec3/matched_filter_unittest.cc new file mode 100644 index 0000000000..993ebc8b92 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/matched_filter_unittest.cc @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/matched_filter.h" + +#include +#include +#include + +#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" +#include "webrtc/modules/audio_processing/test/echo_canceller_test_tools.h" +#include "webrtc/test/gtest.h" + +namespace webrtc { +namespace { + +std::string ProduceDebugText(size_t delay) { + std::ostringstream ss; + ss << "Delay: " << delay; + return ss.str(); +} + +constexpr size_t kWindowSizeSubBlocks = 32; +constexpr size_t kAlignmentShiftSubBlocks = kWindowSizeSubBlocks * 3 / 4; +constexpr size_t kNumMatchedFilters = 4; + +} // namespace + +// Verifies that the matched filter produces proper lag estimates for +// artificially +// delayed signals. +TEST(MatchedFilter, LagEstimation) { + Random random_generator(42U); + std::array render; + std::array capture; + render.fill(0.f); + capture.fill(0.f); + ApmDataDumper data_dumper(0); + for (size_t delay_samples : {0, 64, 150, 200, 800, 1000}) { + SCOPED_TRACE(ProduceDebugText(delay_samples)); + DelayBuffer signal_delay_buffer(delay_samples); + MatchedFilter filter(&data_dumper, kWindowSizeSubBlocks, kNumMatchedFilters, + kAlignmentShiftSubBlocks); + + // Analyze the correlation between render and capture. + for (size_t k = 0; k < (100 + delay_samples / kSubBlockSize); ++k) { + RandomizeSampleVector(&random_generator, render); + signal_delay_buffer.Delay(render, capture); + filter.Update(render, capture); + } + + // Obtain the lag estimates. + auto lag_estimates = filter.GetLagEstimates(); + + // Find which lag estimate should be the most accurate. + rtc::Optional expected_most_accurate_lag_estimate; + size_t alignment_shift_sub_blocks = 0; + for (size_t k = 0; k < kNumMatchedFilters; ++k) { + if ((alignment_shift_sub_blocks + kWindowSizeSubBlocks / 2) * + kSubBlockSize > + delay_samples) { + expected_most_accurate_lag_estimate = rtc::Optional(k); + break; + } + alignment_shift_sub_blocks += kAlignmentShiftSubBlocks; + } + ASSERT_TRUE(expected_most_accurate_lag_estimate); + + // Verify that the expected most accurate lag estimate is the most accurate + // estimate. + for (size_t k = 0; k < kNumMatchedFilters; ++k) { + if (k != *expected_most_accurate_lag_estimate) { + EXPECT_GT(lag_estimates[*expected_most_accurate_lag_estimate].accuracy, + lag_estimates[k].accuracy); + } + } + + // Verify that all lag estimates are updated as expected for signals + // containing strong noise. + for (auto& le : lag_estimates) { + EXPECT_TRUE(le.updated); + } + + // Verify that the expected most accurate lag estimate is reliable. + EXPECT_TRUE(lag_estimates[*expected_most_accurate_lag_estimate].reliable); + + // Verify that the expected most accurate lag estimate is correct. + EXPECT_EQ(delay_samples, + lag_estimates[*expected_most_accurate_lag_estimate].lag); + } +} + +// Verifies that the matched filter does not produce reliable and accurate +// estimates for uncorrelated render and capture signals. +TEST(MatchedFilter, LagNotReliableForUncorrelatedRenderAndCapture) { + Random random_generator(42U); + std::array render; + std::array capture; + render.fill(0.f); + capture.fill(0.f); + ApmDataDumper data_dumper(0); + MatchedFilter filter(&data_dumper, kWindowSizeSubBlocks, kNumMatchedFilters, + kAlignmentShiftSubBlocks); + + // Analyze the correlation between render and capture. + for (size_t k = 0; k < 100; ++k) { + RandomizeSampleVector(&random_generator, render); + RandomizeSampleVector(&random_generator, capture); + filter.Update(render, capture); + } + + // Obtain the lag estimates. + auto lag_estimates = filter.GetLagEstimates(); + EXPECT_EQ(kNumMatchedFilters, lag_estimates.size()); + + // Verify that no lag estimates are reliable. + for (auto& le : lag_estimates) { + EXPECT_FALSE(le.reliable); + } +} + +// Verifies that the matched filter does not produce updated lag estimates for +// render signals of low level. +TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) { + Random random_generator(42U); + std::array render; + std::array capture; + render.fill(0.f); + capture.fill(0.f); + ApmDataDumper data_dumper(0); + MatchedFilter filter(&data_dumper, kWindowSizeSubBlocks, kNumMatchedFilters, + kAlignmentShiftSubBlocks); + + // Analyze the correlation between render and capture. + for (size_t k = 0; k < 100; ++k) { + RandomizeSampleVector(&random_generator, render); + for (auto& render_k : render) { + render_k *= 149.f / 32767.f; + } + std::copy(render.begin(), render.end(), capture.begin()); + filter.Update(render, capture); + } + + // Obtain the lag estimates. + auto lag_estimates = filter.GetLagEstimates(); + EXPECT_EQ(kNumMatchedFilters, lag_estimates.size()); + + // Verify that no lag estimates are updated and that no lag estimates are + // reliable. + for (auto& le : lag_estimates) { + EXPECT_FALSE(le.updated); + EXPECT_FALSE(le.reliable); + } +} + +// Verifies that the correct number of lag estimates are produced for a certain +// number of alignment shifts. +TEST(MatchedFilter, NumberOfLagEstimates) { + ApmDataDumper data_dumper(0); + for (size_t num_matched_filters = 0; num_matched_filters < 10; + ++num_matched_filters) { + MatchedFilter filter(&data_dumper, 32, num_matched_filters, 1); + EXPECT_EQ(num_matched_filters, filter.GetLagEstimates().size()); + } +} + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies the check for non-zero windows size. +TEST(MatchedFilter, ZeroWindowSize) { + ApmDataDumper data_dumper(0); + EXPECT_DEATH(MatchedFilter(&data_dumper, 0, 1, 1), ""); +} + +// Verifies the check for non-null data dumper. +TEST(MatchedFilter, NullDataDumper) { + EXPECT_DEATH(MatchedFilter(nullptr, 1, 1, 1), ""); +} + +#endif + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/render_delay_controller.cc b/webrtc/modules/audio_processing/aec3/render_delay_controller.cc index e18701e5ea..981b744eef 100644 --- a/webrtc/modules/audio_processing/aec3/render_delay_controller.cc +++ b/webrtc/modules/audio_processing/aec3/render_delay_controller.cc @@ -115,9 +115,8 @@ RenderDelayControllerImpl::RenderDelayControllerImpl( max_delay_(render_delay_buffer.MaxDelay()), delay_(render_delay_buffer.Delay()), render_buffer_(render_delay_buffer.MaxApiJitter() + 1), - delay_estimator_(data_dumper_.get(), sample_rate_hz) { - RTC_DCHECK(sample_rate_hz == 8000 || sample_rate_hz == 16000 || - sample_rate_hz == 32000 || sample_rate_hz == 48000); + delay_estimator_(data_dumper_.get()) { + RTC_DCHECK(ValidFullBandRate(sample_rate_hz)); } RenderDelayControllerImpl::~RenderDelayControllerImpl() = default; diff --git a/webrtc/modules/audio_processing/aec3/render_delay_controller_unittest.cc b/webrtc/modules/audio_processing/aec3/render_delay_controller_unittest.cc index 145f3a3042..9d382adb0f 100644 --- a/webrtc/modules/audio_processing/aec3/render_delay_controller_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/render_delay_controller_unittest.cc @@ -105,8 +105,7 @@ TEST(RenderDelayController, BasicApiCalls) { // Verifies that the RenderDelayController is able to align the signals for // simple timeshifts between the signals. -// TODO(peah): Activate the unittest once the required code has been landed. -TEST(RenderDelayController, DISABLED_Alignment) { +TEST(RenderDelayController, Alignment) { Random random_generator(42U); std::vector render_block(kBlockSize, 0.f); std::vector capture_block(kBlockSize, 0.f); @@ -148,8 +147,7 @@ TEST(RenderDelayController, DISABLED_Alignment) { // Verifies that the RenderDelayController is able to align the signals for // simple timeshifts between the signals when there is jitter in the API calls. -// TODO(peah): Activate the unittest once the required code has been landed. -TEST(RenderDelayController, DISABLED_AlignmentWithJitter) { +TEST(RenderDelayController, AlignmentWithJitter) { Random random_generator(42U); std::vector render_block(kBlockSize, 0.f); std::vector capture_block(kBlockSize, 0.f);