diff --git a/webrtc/modules/audio_processing/BUILD.gn b/webrtc/modules/audio_processing/BUILD.gn index dfedbee11c..2684a782c3 100644 --- a/webrtc/modules/audio_processing/BUILD.gn +++ b/webrtc/modules/audio_processing/BUILD.gn @@ -33,6 +33,8 @@ rtc_static_library("audio_processing") { "aec3/block_processor.h", "aec3/cascaded_biquad_filter.cc", "aec3/cascaded_biquad_filter.h", + "aec3/decimator_by_4.cc", + "aec3/decimator_by_4.h", "aec3/echo_canceller3.cc", "aec3/echo_canceller3.h", "aec3/echo_path_delay_estimator.cc", @@ -42,6 +44,10 @@ rtc_static_library("audio_processing") { "aec3/echo_remover.h", "aec3/frame_blocker.cc", "aec3/frame_blocker.h", + "aec3/matched_filter.cc", + "aec3/matched_filter.h", + "aec3/matched_filter_lag_aggregator.cc", + "aec3/matched_filter_lag_aggregator.h", "aec3/render_delay_buffer.cc", "aec3/render_delay_buffer.h", "aec3/render_delay_controller.cc", @@ -519,10 +525,13 @@ if (rtc_include_tests) { "aec3/block_framer_unittest.cc", "aec3/block_processor_unittest.cc", "aec3/cascaded_biquad_filter_unittest.cc", + "aec3/decimator_by_4_unittest.cc", "aec3/echo_canceller3_unittest.cc", "aec3/echo_path_delay_estimator_unittest.cc", "aec3/echo_remover_unittest.cc", "aec3/frame_blocker_unittest.cc", + "aec3/matched_filter_lag_aggregator_unittest.cc", + "aec3/matched_filter_unittest.cc", "aec3/mock/mock_block_processor.h", "aec3/mock/mock_echo_remover.h", "aec3/mock/mock_render_delay_buffer.h", diff --git a/webrtc/modules/audio_processing/aec3/aec3_constants.h b/webrtc/modules/audio_processing/aec3/aec3_constants.h index 946e50c786..054b0d8afd 100644 --- a/webrtc/modules/audio_processing/aec3/aec3_constants.h +++ b/webrtc/modules/audio_processing/aec3/aec3_constants.h @@ -34,6 +34,11 @@ constexpr int LowestBandRate(int sample_rate_hz) { return sample_rate_hz == 8000 ? sample_rate_hz : 16000; } +constexpr bool ValidFullBandRate(int sample_rate_hz) { + return sample_rate_hz == 8000 || sample_rate_hz == 16000 || + sample_rate_hz == 32000 || sample_rate_hz == 48000; +} + static_assert(1 == NumBandsForRate(8000), "Number of bands for 8 kHz"); static_assert(1 == NumBandsForRate(16000), "Number of bands for 16 kHz"); static_assert(2 == NumBandsForRate(32000), "Number of bands for 32 kHz"); @@ -47,6 +52,17 @@ static_assert(16000 == LowestBandRate(32000), static_assert(16000 == LowestBandRate(48000), "Sample rate of band 0 for 48 kHz"); +static_assert(ValidFullBandRate(8000), + "Test that 8 kHz is a valid sample rate"); +static_assert(ValidFullBandRate(16000), + "Test that 16 kHz is a valid sample rate"); +static_assert(ValidFullBandRate(32000), + "Test that 32 kHz is a valid sample rate"); +static_assert(ValidFullBandRate(48000), + "Test that 48 kHz is a valid sample rate"); +static_assert(!ValidFullBandRate(8001), + "Test that 8001 Hz is not a valid sample rate"); + } // namespace webrtc #endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_AEC3_CONSTANTS_H_ diff --git a/webrtc/modules/audio_processing/aec3/block_processor.cc b/webrtc/modules/audio_processing/aec3/block_processor.cc index 1f50a84854..550a21073f 100644 --- a/webrtc/modules/audio_processing/aec3/block_processor.cc +++ b/webrtc/modules/audio_processing/aec3/block_processor.cc @@ -61,7 +61,9 @@ BlockProcessorImpl::BlockProcessorImpl( sample_rate_hz_(sample_rate_hz), render_buffer_(std::move(render_buffer)), delay_controller_(std::move(delay_controller)), - echo_remover_(std::move(echo_remover)) {} + echo_remover_(std::move(echo_remover)) { + RTC_DCHECK(ValidFullBandRate(sample_rate_hz_)); +} BlockProcessorImpl::~BlockProcessorImpl() = default; diff --git a/webrtc/modules/audio_processing/aec3/block_processor_unittest.cc b/webrtc/modules/audio_processing/aec3/block_processor_unittest.cc index b5d6a1432f..ac3af6afb1 100644 --- a/webrtc/modules/audio_processing/aec3/block_processor_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/block_processor_unittest.cc @@ -252,6 +252,12 @@ TEST(BlockProcessor, NullBufferRenderParameter) { ""); } +// Verifies the check for correct sample rate. +TEST(BlockProcessor, WrongSampleRate) { + EXPECT_DEATH(std::unique_ptr(BlockProcessor::Create(8001)), + ""); +} + #endif } // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/decimator_by_4.cc b/webrtc/modules/audio_processing/aec3/decimator_by_4.cc new file mode 100644 index 0000000000..3f4c858ec3 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/decimator_by_4.cc @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "webrtc/modules/audio_processing/aec3/decimator_by_4.h" + +#include "webrtc/base/checks.h" + +namespace webrtc { +namespace { + +// [B,A] = butter(2,1500/16000) which are the same as [B,A] = +// butter(2,750/8000). +const CascadedBiQuadFilter::BiQuadCoefficients kLowPassFilterCoefficients = { + {0.0179f, 0.0357f, 0.0179f}, + {-1.5879f, 0.6594f}}; + +} // namespace + +DecimatorBy4::DecimatorBy4() + : low_pass_filter_(kLowPassFilterCoefficients, 3) {} + +void DecimatorBy4::Decimate(rtc::ArrayView in, + std::array* out) { + RTC_DCHECK_EQ(kBlockSize, in.size()); + RTC_DCHECK(out); + std::array x; + + // Limit the frequency content of the signal to avoid aliasing. + low_pass_filter_.Process(in, x); + + // Downsample the signal. + for (size_t j = 0, k = 0; j < out->size(); ++j, k += 4) { + RTC_DCHECK_GT(kBlockSize, k); + (*out)[j] = x[k]; + } +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/decimator_by_4.h b/webrtc/modules/audio_processing/aec3/decimator_by_4.h new file mode 100644 index 0000000000..076c1688c8 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/decimator_by_4.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_DECIMATOR_BY_4_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_DECIMATOR_BY_4_H_ + +#include + +#include "webrtc/base/array_view.h" +#include "webrtc/base/constructormagic.h" +#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/aec3/cascaded_biquad_filter.h" + +namespace webrtc { + +// Provides functionality for decimating a signal by 4. +class DecimatorBy4 { + public: + DecimatorBy4(); + + // Downsamples the signal. + void Decimate(rtc::ArrayView in, + std::array* out); + + private: + CascadedBiQuadFilter low_pass_filter_; + + RTC_DISALLOW_COPY_AND_ASSIGN(DecimatorBy4); +}; +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_DECIMATOR_BY_4_H_ diff --git a/webrtc/modules/audio_processing/aec3/decimator_by_4_unittest.cc b/webrtc/modules/audio_processing/aec3/decimator_by_4_unittest.cc new file mode 100644 index 0000000000..a7699ba64f --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/decimator_by_4_unittest.cc @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/decimator_by_4.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/test/gtest.h" + +namespace webrtc { + +namespace { + +std::string ProduceDebugText(int sample_rate_hz) { + std::ostringstream ss; + ss << "Sample rate: " << sample_rate_hz; + return ss.str(); +} + +constexpr float kPi = 3.141592f; +constexpr size_t kNumStartupBlocks = 50; +constexpr size_t kNumBlocks = 1000; + +void ProduceDecimatedSinusoidalOutputPower(int sample_rate_hz, + float sinusoidal_frequency_hz, + float* input_power, + float* output_power) { + float input[kBlockSize * kNumBlocks]; + + // Produce a sinusoid of the specified frequency. + for (size_t k = 0; k < kBlockSize * kNumBlocks; ++k) { + input[k] = + 32767.f * sin(2.f * kPi * sinusoidal_frequency_hz * k / sample_rate_hz); + } + + DecimatorBy4 decimator; + std::array output; + + for (size_t k = 0; k < kNumBlocks; ++k) { + std::array sub_block; + + decimator.Decimate( + rtc::ArrayView(&input[k * kBlockSize], kBlockSize), + &sub_block); + + std::copy(sub_block.begin(), sub_block.end(), + output.begin() + k * kSubBlockSize); + } + + ASSERT_GT(kNumBlocks, kNumStartupBlocks); + rtc::ArrayView input_to_evaluate( + &input[kNumStartupBlocks * kBlockSize], + (kNumBlocks - kNumStartupBlocks) * kBlockSize); + rtc::ArrayView output_to_evaluate( + &output[kNumStartupBlocks * kSubBlockSize], + (kNumBlocks - kNumStartupBlocks) * kSubBlockSize); + *input_power = + std::inner_product(input_to_evaluate.begin(), input_to_evaluate.end(), + input_to_evaluate.begin(), 0.f) / + input_to_evaluate.size(); + *output_power = + std::inner_product(output_to_evaluate.begin(), output_to_evaluate.end(), + output_to_evaluate.begin(), 0.f) / + output_to_evaluate.size(); +} + +} // namespace + +// Verifies that there is little aliasing from upper frequencies in the +// downsampling. +TEST(DecimatorBy4, NoLeakageFromUpperFrequencies) { + float input_power; + float output_power; + for (auto rate : {8000, 16000, 32000, 48000}) { + ProduceDebugText(rate); + ProduceDecimatedSinusoidalOutputPower(rate, 3.f / 8.f * rate, &input_power, + &output_power); + EXPECT_GT(0.0001f * input_power, output_power); + } +} + +// Verifies that the impact of low-frequency content is small during the +// downsampling. +TEST(DecimatorBy4, NoImpactOnLowerFrequencies) { + float input_power; + float output_power; + for (auto rate : {8000, 16000, 32000, 48000}) { + ProduceDebugText(rate); + ProduceDecimatedSinusoidalOutputPower(rate, 200.f, &input_power, + &output_power); + EXPECT_LT(0.7f * input_power, output_power); + } +} + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) +// Verifies the check for the input size. +TEST(DecimatorBy4, WrongInputSize) { + DecimatorBy4 decimator; + std::vector x(std::vector(kBlockSize - 1, 0.f)); + std::array x_downsampled; + EXPECT_DEATH(decimator.Decimate(x, &x_downsampled), ""); +} + +// Verifies the check for non-null output parameter. +TEST(DecimatorBy4, NullOutput) { + DecimatorBy4 decimator; + std::vector x(std::vector(kBlockSize, 0.f)); + EXPECT_DEATH(decimator.Decimate(x, nullptr), ""); +} + +#endif + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/echo_canceller3.cc b/webrtc/modules/audio_processing/aec3/echo_canceller3.cc index ec7b55a215..a4b796d969 100644 --- a/webrtc/modules/audio_processing/aec3/echo_canceller3.cc +++ b/webrtc/modules/audio_processing/aec3/echo_canceller3.cc @@ -227,6 +227,8 @@ EchoCanceller3::EchoCanceller3(int sample_rate_hz, std::vector(frame_length_, 0.f)), block_(num_bands_, std::vector(kBlockSize, 0.f)), sub_frame_view_(num_bands_) { + RTC_DCHECK(ValidFullBandRate(sample_rate_hz_)); + std::unique_ptr render_highpass_filter; if (use_highpass_filter) { render_highpass_filter.reset(new CascadedBiQuadFilter( diff --git a/webrtc/modules/audio_processing/aec3/echo_canceller3_unittest.cc b/webrtc/modules/audio_processing/aec3/echo_canceller3_unittest.cc index afe429d1d2..8ccaa51e52 100644 --- a/webrtc/modules/audio_processing/aec3/echo_canceller3_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/echo_canceller3_unittest.cc @@ -719,6 +719,12 @@ TEST(EchoCanceller3InputCheck, NullCaptureProcessingParameter) { EXPECT_DEATH(EchoCanceller3(8000, false).ProcessCapture(nullptr, false), ""); } +// Verifies the check for correct sample rate. +TEST(EchoCanceller3InputCheck, WrongSampleRate) { + ApmDataDumper data_dumper(0); + EXPECT_DEATH(EchoCanceller3(8001, false), ""); +} + #endif } // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.cc b/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.cc index 3ae8f879ef..539832df03 100644 --- a/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.cc +++ b/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.cc @@ -9,29 +9,71 @@ */ #include "webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.h" +#include +#include + #include "webrtc/base/checks.h" #include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/include/audio_processing.h" #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" namespace webrtc { -// TODO(peah): Add functionality. -EchoPathDelayEstimator::EchoPathDelayEstimator(ApmDataDumper* data_dumper, - int sample_rate_hz) { +namespace { + +constexpr size_t kNumMatchedFilters = 4; +constexpr size_t kMatchedFilterWindowSizeSubBlocks = 32; +constexpr size_t kMatchedFilterAlignmentShiftSizeSubBlocks = + kMatchedFilterWindowSizeSubBlocks * 3 / 4; + +constexpr int kDownSamplingFactor = 4; +} // namespace + +EchoPathDelayEstimator::EchoPathDelayEstimator(ApmDataDumper* data_dumper) + : data_dumper_(data_dumper), + matched_filter_(data_dumper_, + kMatchedFilterWindowSizeSubBlocks, + kNumMatchedFilters, + kMatchedFilterAlignmentShiftSizeSubBlocks), + matched_filter_lag_aggregator_(data_dumper_, + matched_filter_.NumLagEstimates()) { RTC_DCHECK(data_dumper); - RTC_DCHECK(sample_rate_hz == 8000 || sample_rate_hz == 16000 || - sample_rate_hz == 32000 || sample_rate_hz == 48000); } EchoPathDelayEstimator::~EchoPathDelayEstimator() = default; -// TODO(peah): Add functionality. rtc::Optional EchoPathDelayEstimator::EstimateDelay( rtc::ArrayView render, rtc::ArrayView capture) { - RTC_DCHECK_EQ(render.size(), kBlockSize); - RTC_DCHECK_EQ(capture.size(), kBlockSize); - return rtc::Optional(); + RTC_DCHECK_EQ(kBlockSize, capture.size()); + RTC_DCHECK_EQ(render.size(), capture.size()); + + std::array downsampled_render; + std::array downsampled_capture; + + render_decimator_.Decimate(render, &downsampled_render); + capture_decimator_.Decimate(capture, &downsampled_capture); + + matched_filter_.Update(downsampled_render, downsampled_capture); + + rtc::Optional aggregated_matched_filter_lag = + matched_filter_lag_aggregator_.Aggregate( + matched_filter_.GetLagEstimates()); + + // TODO(peah): Move this logging outside of this class once EchoCanceller3 + // development is done. + data_dumper_->DumpRaw("aec3_echo_path_delay_estimator_delay", + aggregated_matched_filter_lag + ? static_cast(*aggregated_matched_filter_lag * + kDownSamplingFactor) + : -1); + + // Return the detected delay in samples as the aggregated matched filter lag + // compensated by the down sampling factor for the signal being correlated. + return aggregated_matched_filter_lag + ? rtc::Optional(*aggregated_matched_filter_lag * + kDownSamplingFactor) + : rtc::Optional(); } } // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.h b/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.h index ae8ab4318e..bbe9a7c32a 100644 --- a/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.h +++ b/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.h @@ -15,19 +15,31 @@ #include "webrtc/base/constructormagic.h" #include "webrtc/base/optional.h" +#include "webrtc/modules/audio_processing/aec3/matched_filter.h" +#include "webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.h" +#include "webrtc/modules/audio_processing/aec3/decimator_by_4.h" namespace webrtc { class ApmDataDumper; +// Estimates the delay of the echo path. class EchoPathDelayEstimator { public: - EchoPathDelayEstimator(ApmDataDumper* data_dumper, int sample_rate_hz); + explicit EchoPathDelayEstimator(ApmDataDumper* data_dumper); ~EchoPathDelayEstimator(); + + // Produce a delay estimate if such is avaliable. rtc::Optional EstimateDelay(rtc::ArrayView render, rtc::ArrayView capture); private: + ApmDataDumper* const data_dumper_; + DecimatorBy4 render_decimator_; + DecimatorBy4 capture_decimator_; + MatchedFilter matched_filter_; + MatchedFilterLagAggregator matched_filter_lag_aggregator_; + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(EchoPathDelayEstimator); }; } // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc b/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc index c3146ed7b8..ba9ff23540 100644 --- a/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc @@ -10,19 +10,22 @@ #include "webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.h" +#include #include #include +#include "webrtc/base/random.h" #include "webrtc/modules/audio_processing/aec3/aec3_constants.h" #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" +#include "webrtc/modules/audio_processing/test/echo_canceller_test_tools.h" #include "webrtc/test/gtest.h" namespace webrtc { namespace { -std::string ProduceDebugText(int sample_rate_hz) { +std::string ProduceDebugText(size_t delay) { std::ostringstream ss; - ss << "Sample rate: " << sample_rate_hz; + ss << "Delay: " << delay; return ss.str(); } @@ -30,57 +33,120 @@ std::string ProduceDebugText(int sample_rate_hz) { // Verifies that the basic API calls work. TEST(EchoPathDelayEstimator, BasicApiCalls) { - for (auto rate : {8000, 16000, 32000, 48000}) { - ProduceDebugText(rate); - ApmDataDumper data_dumper(0); - EchoPathDelayEstimator estimator(&data_dumper, rate); - std::vector render(kBlockSize, 0.f); - std::vector capture(kBlockSize, 0.f); - for (size_t k = 0; k < 100; ++k) { - estimator.EstimateDelay(render, capture); + ApmDataDumper data_dumper(0); + EchoPathDelayEstimator estimator(&data_dumper); + std::vector render(kBlockSize, 0.f); + std::vector capture(kBlockSize, 0.f); + for (size_t k = 0; k < 100; ++k) { + estimator.EstimateDelay(render, capture); + } +} + +// Verifies that the delay estimator produces correct delay for artificially +// delayed signals. +TEST(EchoPathDelayEstimator, DelayEstimation) { + Random random_generator(42U); + std::vector render(kBlockSize, 0.f); + std::vector capture(kBlockSize, 0.f); + ApmDataDumper data_dumper(0); + for (size_t delay_samples : {0, 64, 150, 200, 800, 4000}) { + SCOPED_TRACE(ProduceDebugText(delay_samples)); + DelayBuffer signal_delay_buffer(delay_samples); + EchoPathDelayEstimator estimator(&data_dumper); + + rtc::Optional estimated_delay_samples; + for (size_t k = 0; k < (100 + delay_samples / kBlockSize); ++k) { + RandomizeSampleVector(&random_generator, render); + signal_delay_buffer.Delay(render, capture); + estimated_delay_samples = estimator.EstimateDelay(render, capture); } + if (estimated_delay_samples) { + // Due to the internal down-sampling by 4 done inside the delay estimator + // the estimated delay cannot be expected to be closer than 4 samples to + // the true delay. + EXPECT_NEAR(delay_samples, *estimated_delay_samples, 4); + } else { + ADD_FAILURE(); + } + } +} + +// Verifies that the delay estimator does not produce delay estimates too +// quickly. +TEST(EchoPathDelayEstimator, NoInitialDelayestimates) { + Random random_generator(42U); + std::vector render(kBlockSize, 0.f); + std::vector capture(kBlockSize, 0.f); + ApmDataDumper data_dumper(0); + + EchoPathDelayEstimator estimator(&data_dumper); + for (size_t k = 0; k < 19; ++k) { + RandomizeSampleVector(&random_generator, render); + std::copy(render.begin(), render.end(), capture.begin()); + EXPECT_FALSE(estimator.EstimateDelay(render, capture)); + } +} + +// Verifies that the delay estimator does not produce delay estimates for render +// signals of low level. +TEST(EchoPathDelayEstimator, NoDelayEstimatesForLowLevelRenderSignals) { + Random random_generator(42U); + std::vector render(kBlockSize, 0.f); + std::vector capture(kBlockSize, 0.f); + ApmDataDumper data_dumper(0); + EchoPathDelayEstimator estimator(&data_dumper); + for (size_t k = 0; k < 100; ++k) { + RandomizeSampleVector(&random_generator, render); + for (auto& render_k : render) { + render_k *= 100.f / 32767.f; + } + std::copy(render.begin(), render.end(), capture.begin()); + EXPECT_FALSE(estimator.EstimateDelay(render, capture)); + } +} + +// Verifies that the delay estimator does not produce delay estimates for +// uncorrelated signals. +TEST(EchoPathDelayEstimator, NoDelayEstimatesForUncorrelatedSignals) { + Random random_generator(42U); + std::vector render(kBlockSize, 0.f); + std::vector capture(kBlockSize, 0.f); + ApmDataDumper data_dumper(0); + EchoPathDelayEstimator estimator(&data_dumper); + for (size_t k = 0; k < 100; ++k) { + RandomizeSampleVector(&random_generator, render); + RandomizeSampleVector(&random_generator, capture); + EXPECT_FALSE(estimator.EstimateDelay(render, capture)); } } #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) -// Verifies the check for correct sample rate. -TEST(EchoPathDelayEstimator, WrongSampleRate) { - ApmDataDumper data_dumper(0); - EXPECT_DEATH(EchoPathDelayEstimator remover(&data_dumper, 8001), ""); -} - // Verifies the check for the render blocksize. // TODO(peah): Re-enable the test once the issue with memory leaks during DEATH // tests on test bots has been fixed. TEST(EchoPathDelayEstimator, DISABLED_WrongRenderBlockSize) { - for (auto rate : {8000, 16000, 32000, 48000}) { - ProduceDebugText(rate); - ApmDataDumper data_dumper(0); - EchoPathDelayEstimator estimator(&data_dumper, rate); - std::vector render(kBlockSize - 1, 0.f); - std::vector capture(kBlockSize, 0.f); - EXPECT_DEATH(estimator.EstimateDelay(render, capture), ""); - } + ApmDataDumper data_dumper(0); + EchoPathDelayEstimator estimator(&data_dumper); + std::vector render(std::vector(kBlockSize - 1, 0.f)); + std::vector capture(std::vector(kBlockSize, 0.f)); + EXPECT_DEATH(estimator.EstimateDelay(render, capture), ""); } // Verifies the check for the capture blocksize. // TODO(peah): Re-enable the test once the issue with memory leaks during DEATH // tests on test bots has been fixed. -TEST(EchoPathDelayEstimator, DISABLED_WrongCaptureBlockSize) { - for (auto rate : {8000, 16000, 32000, 48000}) { - ProduceDebugText(rate); - ApmDataDumper data_dumper(0); - EchoPathDelayEstimator estimator(&data_dumper, rate); - std::vector render(kBlockSize, 0.f); - std::vector capture(kBlockSize - 1, 0.f); - EXPECT_DEATH(estimator.EstimateDelay(render, capture), ""); - } +TEST(EchoPathDelayEstimator, WrongCaptureBlockSize) { + ApmDataDumper data_dumper(0); + EchoPathDelayEstimator estimator(&data_dumper); + std::vector render(std::vector(kBlockSize, 0.f)); + std::vector capture(std::vector(kBlockSize - 1, 0.f)); + EXPECT_DEATH(estimator.EstimateDelay(render, capture), ""); } // Verifies the check for non-null data dumper. TEST(EchoPathDelayEstimator, NullDataDumper) { - EXPECT_DEATH(EchoPathDelayEstimator(nullptr, 8000), ""); + EXPECT_DEATH(EchoPathDelayEstimator(nullptr), ""); } #endif diff --git a/webrtc/modules/audio_processing/aec3/echo_remover.cc b/webrtc/modules/audio_processing/aec3/echo_remover.cc index 2ae5525b4a..ab0b68bb16 100644 --- a/webrtc/modules/audio_processing/aec3/echo_remover.cc +++ b/webrtc/modules/audio_processing/aec3/echo_remover.cc @@ -42,8 +42,7 @@ class EchoRemoverImpl final : public EchoRemover { // TODO(peah): Add functionality. EchoRemoverImpl::EchoRemoverImpl(int sample_rate_hz) : sample_rate_hz_(sample_rate_hz) { - RTC_DCHECK(sample_rate_hz == 8000 || sample_rate_hz == 16000 || - sample_rate_hz == 32000 || sample_rate_hz == 48000); + RTC_DCHECK(ValidFullBandRate(sample_rate_hz_)); } EchoRemoverImpl::~EchoRemoverImpl() = default; diff --git a/webrtc/modules/audio_processing/aec3/matched_filter.cc b/webrtc/modules/audio_processing/aec3/matched_filter.cc new file mode 100644 index 0000000000..f187159911 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/matched_filter.cc @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "webrtc/modules/audio_processing/aec3/matched_filter.h" + +#include +#include + +#include "webrtc/modules/audio_processing/include/audio_processing.h" +#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" + +namespace webrtc { + +MatchedFilter::IndexedBuffer::IndexedBuffer(size_t size) : data(size, 0.f) { + RTC_DCHECK_EQ(0, size % kSubBlockSize); +} + +MatchedFilter::IndexedBuffer::~IndexedBuffer() = default; + +MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper, + size_t window_size_sub_blocks, + int num_matched_filters, + size_t alignment_shift_sub_blocks) + : data_dumper_(data_dumper), + filter_intra_lag_shift_(alignment_shift_sub_blocks * kSubBlockSize), + filters_(num_matched_filters, + std::vector(window_size_sub_blocks * kSubBlockSize, 0.f)), + lag_estimates_(num_matched_filters), + x_buffer_(kSubBlockSize * + (alignment_shift_sub_blocks * num_matched_filters + + window_size_sub_blocks + + 1)) { + RTC_DCHECK(data_dumper); + RTC_DCHECK_EQ(0, x_buffer_.data.size() % kSubBlockSize); + RTC_DCHECK_LT(0, window_size_sub_blocks); +} + +MatchedFilter::~MatchedFilter() = default; + +void MatchedFilter::Update(const std::array& render, + const std::array& capture) { + const std::array& x = render; + const std::array& y = capture; + + const float x2_sum_threshold = filters_[0].size() * 150.f * 150.f; + + // Insert the new subblock into x_buffer. + x_buffer_.index = (x_buffer_.index - kSubBlockSize + x_buffer_.data.size()) % + x_buffer_.data.size(); + RTC_DCHECK_LE(kSubBlockSize, x_buffer_.data.size() - x_buffer_.index); + std::copy(x.rbegin(), x.rend(), x_buffer_.data.begin() + x_buffer_.index); + + // Apply all matched filters. + size_t alignment_shift = 0; + for (size_t n = 0; n < filters_.size(); ++n) { + float error_sum = 0.f; + bool filters_updated = false; + size_t x_start_index = + (x_buffer_.index + alignment_shift + kSubBlockSize - 1) % + x_buffer_.data.size(); + + // Process for all samples in the sub-block. + for (size_t i = 0; i < kSubBlockSize; ++i) { + // As x_buffer is a circular buffer, all of the processing is split into + // two loops around the wrapping of the buffer. + const size_t loop_size_1 = + std::min(filters_[n].size(), x_buffer_.data.size() - x_start_index); + const size_t loop_size_2 = filters_[n].size() - loop_size_1; + RTC_DCHECK_EQ(filters_[n].size(), loop_size_1 + loop_size_2); + + // x * x. + float x2_sum = std::inner_product( + x_buffer_.data.begin() + x_start_index, + x_buffer_.data.begin() + x_start_index + loop_size_1, + x_buffer_.data.begin() + x_start_index, 0.f); + // Apply the matched filter as filter * x. + float s = std::inner_product(filters_[n].begin(), + filters_[n].begin() + loop_size_1, + x_buffer_.data.begin() + x_start_index, 0.f); + + if (loop_size_2 > 0) { + // Update the cumulative sum of x * x. + x2_sum = std::inner_product(x_buffer_.data.begin(), + x_buffer_.data.begin() + loop_size_2, + x_buffer_.data.begin(), x2_sum); + + // Compute the matched filter output filter * x in a cumulative manner. + s = std::inner_product(x_buffer_.data.begin(), + x_buffer_.data.begin() + loop_size_2, + filters_[n].begin() + loop_size_1, s); + } + + // Compute the matched filter error. + const float e = std::min(32767.f, std::max(-32768.f, y[i] - s)); + error_sum += e * e; + + // Update the matched filter estimate in an NLMS manner. + if (x2_sum > x2_sum_threshold) { + filters_updated = true; + RTC_DCHECK_LT(0.f, x2_sum); + const float alpha = 0.7f * e / x2_sum; + + // filter = filter + 0.7 * (y - filter * x) / x * x. + std::transform(filters_[n].begin(), filters_[n].begin() + loop_size_1, + x_buffer_.data.begin() + x_start_index, + filters_[n].begin(), + [&](float a, float b) { return a + alpha * b; }); + + if (loop_size_2 > 0) { + // filter = filter + 0.7 * (y - filter * x) / x * x. + std::transform(x_buffer_.data.begin(), + x_buffer_.data.begin() + loop_size_2, + filters_[n].begin() + loop_size_1, + filters_[n].begin() + loop_size_1, + [&](float a, float b) { return b + alpha * a; }); + } + } + + x_start_index = + x_start_index > 0 ? x_start_index - 1 : x_buffer_.data.size() - 1; + } + + // Compute anchor for the matched filter error. + const float error_sum_anchor = + std::inner_product(y.begin(), y.end(), y.begin(), 0.f); + + // Estimate the lag in the matched filter as the distance to the portion in + // the filter that contributes the most to the matched filter output. This + // is detected as the peak of the matched filter. + const size_t lag_estimate = std::distance( + filters_[n].begin(), + std::max_element( + filters_[n].begin(), filters_[n].end(), + [](float a, float b) -> bool { return a * a < b * b; })); + + // Update the lag estimates for the matched filter. + const float kMatchingFilterThreshold = 0.3f; + lag_estimates_[n] = + LagEstimate(error_sum_anchor - error_sum, + error_sum < kMatchingFilterThreshold * error_sum_anchor, + lag_estimate + alignment_shift, filters_updated); + + // TODO(peah): Remove once development of EchoCanceller3 is fully done. + RTC_DCHECK_EQ(4, filters_.size()); + switch (n) { + case 0: + data_dumper_->DumpRaw("aec3_correlator_0_h", filters_[0]); + break; + case 1: + data_dumper_->DumpRaw("aec3_correlator_1_h", filters_[1]); + break; + case 2: + data_dumper_->DumpRaw("aec3_correlator_2_h", filters_[2]); + break; + case 3: + data_dumper_->DumpRaw("aec3_correlator_3_h", filters_[3]); + break; + default: + RTC_DCHECK(false); + } + + alignment_shift += filter_intra_lag_shift_; + } +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/matched_filter.h b/webrtc/modules/audio_processing/aec3/matched_filter.h new file mode 100644 index 0000000000..3e09d4b971 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/matched_filter.h @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_ + +#include +#include +#include + +#include "webrtc/base/constructormagic.h" +#include "webrtc/base/optional.h" +#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" + +namespace webrtc { + +class ApmDataDumper; + +// Produces recursively updated cross-correlation estimates for several signal +// shifts where the intra-shift spacing is uniform. +class MatchedFilter { + public: + // Stores properties for the lag estimate corresponding to a particular signal + // shift. + struct LagEstimate { + LagEstimate() = default; + LagEstimate(float accuracy, bool reliable, size_t lag, bool updated) + : accuracy(accuracy), reliable(reliable), lag(lag), updated(updated) {} + + float accuracy = 0.f; + bool reliable = false; + size_t lag = 0; + bool updated = false; + }; + + MatchedFilter(ApmDataDumper* data_dumper, + size_t window_size_sub_blocks, + int num_matched_filters, + size_t alignment_shift_sub_blocks); + + ~MatchedFilter(); + + // Updates the correlation with the values in render and capture. + void Update(const std::array& render, + const std::array& capture); + + // Returns the current lag estimates. + rtc::ArrayView GetLagEstimates() const { + return lag_estimates_; + } + + // Returns the number of lag estimates produced using the shifted signals. + size_t NumLagEstimates() const { return filters_.size(); } + + private: + // Provides buffer with a related index. + struct IndexedBuffer { + explicit IndexedBuffer(size_t size); + ~IndexedBuffer(); + + std::vector data; + int index = 0; + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(IndexedBuffer); + }; + + ApmDataDumper* const data_dumper_; + const size_t filter_intra_lag_shift_; + std::vector> filters_; + std::vector lag_estimates_; + IndexedBuffer x_buffer_; + + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(MatchedFilter); +}; + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_ diff --git a/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc b/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc new file mode 100644 index 0000000000..d9176efa94 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.h" + +#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" + +namespace webrtc { + +MatchedFilterLagAggregator::MatchedFilterLagAggregator( + ApmDataDumper* data_dumper, + size_t num_lag_estimates) + : data_dumper_(data_dumper), lag_updates_in_a_row_(num_lag_estimates, 0) { + RTC_DCHECK(data_dumper); + RTC_DCHECK_LT(0, num_lag_estimates); +} + +MatchedFilterLagAggregator::~MatchedFilterLagAggregator() = default; + +rtc::Optional MatchedFilterLagAggregator::Aggregate( + rtc::ArrayView lag_estimates) { + RTC_DCHECK_EQ(lag_updates_in_a_row_.size(), lag_estimates.size()); + + // Count the number of lag updates in a row to ensure that only stable lags + // are taken into account. + for (size_t k = 0; k < lag_estimates.size(); ++k) { + lag_updates_in_a_row_[k] = + lag_estimates[k].updated ? lag_updates_in_a_row_[k] + 1 : 0; + } + + // If available, choose the strongest lag estimate as the best one. + int best_lag_estimate_index = -1; + for (size_t k = 0; k < lag_estimates.size(); ++k) { + if (lag_updates_in_a_row_[k] > 10 && lag_estimates[k].reliable && + (best_lag_estimate_index == -1 || + lag_estimates[k].accuracy > + lag_estimates[best_lag_estimate_index].accuracy)) { + best_lag_estimate_index = k; + } + } + + // TODO(peah): Remove this logging once all development is done. + data_dumper_->DumpRaw("aec3_echo_path_delay_estimator_best_index", + best_lag_estimate_index); + + // Require the same lag to be detected 10 times in a row before considering + // it reliable. + if (best_lag_estimate_index >= 0) { + candidate_counter_ = + (candidate_ == lag_estimates[best_lag_estimate_index].lag) + ? candidate_counter_ + 1 + : 0; + candidate_ = lag_estimates[best_lag_estimate_index].lag; + } + + return candidate_counter_ >= 10 ? rtc::Optional(candidate_) + : rtc::Optional(); +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.h b/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.h new file mode 100644 index 0000000000..ce8a3d67a0 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_LAG_AGGREGATOR_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_LAG_AGGREGATOR_H_ + +#include + +#include "webrtc/base/constructormagic.h" +#include "webrtc/base/optional.h" +#include "webrtc/modules/audio_processing/aec3/matched_filter.h" + +namespace webrtc { + +class ApmDataDumper; + +// Aggregates lag estimates produced by the MatchedFilter class into a single +// reliable combined lag estimate. +class MatchedFilterLagAggregator { + public: + MatchedFilterLagAggregator(ApmDataDumper* data_dumper, + size_t num_lag_estimates); + ~MatchedFilterLagAggregator(); + + // Aggregates the provided lag estimates. + rtc::Optional Aggregate( + rtc::ArrayView lag_estimates); + + private: + ApmDataDumper* const data_dumper_; + std::vector lag_updates_in_a_row_; + size_t candidate_ = 0; + size_t candidate_counter_ = 0; + + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(MatchedFilterLagAggregator); +}; +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_LAG_AGGREGATOR_H_ diff --git a/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc b/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc new file mode 100644 index 0000000000..b76116ba0b --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.h" + +#include +#include +#include + +#include "webrtc/base/array_view.h" +#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" +#include "webrtc/test/gtest.h" + +namespace webrtc { +namespace { + +void VerifyNoAggregateOutputForRepeatedLagAggregation( + size_t num_repetitions, + rtc::ArrayView lag_estimates, + MatchedFilterLagAggregator* aggregator) { + for (size_t k = 0; k < num_repetitions; ++k) { + EXPECT_FALSE(aggregator->Aggregate(lag_estimates)); + } +} + +constexpr size_t kThresholdForRequiredLagUpdatesInARow = 10; +constexpr size_t kThresholdForRequiredIdenticalLagAggregates = 10; + +} // namespace + +// Verifies that the most accurate lag estimate is chosen. +TEST(MatchedFilterLagAggregator, MostAccurateLagChosen) { + constexpr size_t kArtificialLag1 = 5; + constexpr size_t kArtificialLag2 = 10; + ApmDataDumper data_dumper(0); + std::vector lag_estimates(2); + MatchedFilterLagAggregator aggregator(&data_dumper, lag_estimates.size()); + lag_estimates[0] = + MatchedFilter::LagEstimate(1.f, true, kArtificialLag1, true); + lag_estimates[1] = + MatchedFilter::LagEstimate(0.5f, true, kArtificialLag2, true); + + VerifyNoAggregateOutputForRepeatedLagAggregation( + kThresholdForRequiredLagUpdatesInARow + + kThresholdForRequiredIdenticalLagAggregates, + lag_estimates, &aggregator); + rtc::Optional aggregated_lag = aggregator.Aggregate(lag_estimates); + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kArtificialLag1, *aggregated_lag); + + lag_estimates[0] = + MatchedFilter::LagEstimate(0.5f, true, kArtificialLag1, true); + lag_estimates[1] = + MatchedFilter::LagEstimate(1.f, true, kArtificialLag2, true); + + VerifyNoAggregateOutputForRepeatedLagAggregation( + kThresholdForRequiredIdenticalLagAggregates, lag_estimates, &aggregator); + aggregated_lag = aggregator.Aggregate(lag_estimates); + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kArtificialLag2, *aggregated_lag); +} + +// Verifies that varying lag estimates causes lag estimates to not be deemed +// reliable. +TEST(MatchedFilterLagAggregator, + LagEstimateInvarianceRequiredForAggregatedLag) { + constexpr size_t kArtificialLag1 = 5; + constexpr size_t kArtificialLag2 = 10; + ApmDataDumper data_dumper(0); + std::vector lag_estimates(1); + MatchedFilterLagAggregator aggregator(&data_dumper, lag_estimates.size()); + lag_estimates[0] = + MatchedFilter::LagEstimate(1.f, true, kArtificialLag1, true); + VerifyNoAggregateOutputForRepeatedLagAggregation( + kThresholdForRequiredLagUpdatesInARow + + kThresholdForRequiredIdenticalLagAggregates, + lag_estimates, &aggregator); + rtc::Optional aggregated_lag = aggregator.Aggregate(lag_estimates); + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kArtificialLag1, *aggregated_lag); + + lag_estimates[0] = + MatchedFilter::LagEstimate(1.f, true, kArtificialLag2, true); + + VerifyNoAggregateOutputForRepeatedLagAggregation( + kThresholdForRequiredIdenticalLagAggregates, lag_estimates, &aggregator); + aggregated_lag = aggregator.Aggregate(lag_estimates); + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kArtificialLag2, *aggregated_lag); +} + +// Verifies that lag estimate updates are required to produce an updated lag +// aggregate. +TEST(MatchedFilterLagAggregator, LagEstimateUpdatesRequiredForAggregatedLag) { + constexpr size_t kArtificialLag1 = 5; + constexpr size_t kArtificialLag2 = 10; + ApmDataDumper data_dumper(0); + std::vector lag_estimates(1); + MatchedFilterLagAggregator aggregator(&data_dumper, lag_estimates.size()); + lag_estimates[0] = + MatchedFilter::LagEstimate(1.f, true, kArtificialLag1, true); + VerifyNoAggregateOutputForRepeatedLagAggregation( + kThresholdForRequiredLagUpdatesInARow + + kThresholdForRequiredIdenticalLagAggregates, + lag_estimates, &aggregator); + rtc::Optional aggregated_lag = aggregator.Aggregate(lag_estimates); + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kArtificialLag1, *aggregated_lag); + + lag_estimates[0] = + MatchedFilter::LagEstimate(1.f, true, kArtificialLag2, false); + + for (size_t k = 0; k < kThresholdForRequiredLagUpdatesInARow + + kThresholdForRequiredIdenticalLagAggregates + 1; + ++k) { + aggregated_lag = aggregator.Aggregate(lag_estimates); + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kArtificialLag1, *aggregated_lag); + } + + lag_estimates[0] = + MatchedFilter::LagEstimate(1.f, true, kArtificialLag2, true); + for (size_t k = 0; k < kThresholdForRequiredLagUpdatesInARow; ++k) { + aggregated_lag = aggregator.Aggregate(lag_estimates); + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kArtificialLag1, *aggregated_lag); + } + + VerifyNoAggregateOutputForRepeatedLagAggregation( + kThresholdForRequiredIdenticalLagAggregates, lag_estimates, &aggregator); + + aggregated_lag = aggregator.Aggregate(lag_estimates); + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kArtificialLag2, *aggregated_lag); +} + +// Verifies that an aggregated lag is persistent if the lag estimates do not +// change and that an aggregated lag is not produced without gaining lag +// estimate confidence. +TEST(MatchedFilterLagAggregator, PersistentAggregatedLag) { + constexpr size_t kArtificialLag = 5; + ApmDataDumper data_dumper(0); + std::vector lag_estimates(1); + MatchedFilterLagAggregator aggregator(&data_dumper, lag_estimates.size()); + lag_estimates[0] = + MatchedFilter::LagEstimate(1.f, true, kArtificialLag, true); + VerifyNoAggregateOutputForRepeatedLagAggregation( + kThresholdForRequiredLagUpdatesInARow + + kThresholdForRequiredIdenticalLagAggregates, + lag_estimates, &aggregator); + rtc::Optional aggregated_lag = aggregator.Aggregate(lag_estimates); + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kArtificialLag, *aggregated_lag); + + aggregated_lag = aggregator.Aggregate(lag_estimates); + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kArtificialLag, *aggregated_lag); +} + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies the check for correct number of lag estimates. +TEST(MatchedFilterLagAggregator, IncorrectNumberOfLagEstimates) { + ApmDataDumper data_dumper(0); + MatchedFilterLagAggregator aggregator(&data_dumper, 1); + std::vector lag_estimates(2); + + EXPECT_DEATH(aggregator.Aggregate(lag_estimates), ""); +} + +// Verifies the check for non-zero number of lag estimates. +TEST(MatchedFilterLagAggregator, NonZeroLagEstimates) { + ApmDataDumper data_dumper(0); + EXPECT_DEATH(MatchedFilterLagAggregator(&data_dumper, 0), ""); +} + +// Verifies the check for non-null data dumper. +TEST(MatchedFilterLagAggregator, NullDataDumper) { + EXPECT_DEATH(MatchedFilterLagAggregator(nullptr, 1), ""); +} + +#endif + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/matched_filter_unittest.cc b/webrtc/modules/audio_processing/aec3/matched_filter_unittest.cc new file mode 100644 index 0000000000..993ebc8b92 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/matched_filter_unittest.cc @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/matched_filter.h" + +#include +#include +#include + +#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" +#include "webrtc/modules/audio_processing/test/echo_canceller_test_tools.h" +#include "webrtc/test/gtest.h" + +namespace webrtc { +namespace { + +std::string ProduceDebugText(size_t delay) { + std::ostringstream ss; + ss << "Delay: " << delay; + return ss.str(); +} + +constexpr size_t kWindowSizeSubBlocks = 32; +constexpr size_t kAlignmentShiftSubBlocks = kWindowSizeSubBlocks * 3 / 4; +constexpr size_t kNumMatchedFilters = 4; + +} // namespace + +// Verifies that the matched filter produces proper lag estimates for +// artificially +// delayed signals. +TEST(MatchedFilter, LagEstimation) { + Random random_generator(42U); + std::array render; + std::array capture; + render.fill(0.f); + capture.fill(0.f); + ApmDataDumper data_dumper(0); + for (size_t delay_samples : {0, 64, 150, 200, 800, 1000}) { + SCOPED_TRACE(ProduceDebugText(delay_samples)); + DelayBuffer signal_delay_buffer(delay_samples); + MatchedFilter filter(&data_dumper, kWindowSizeSubBlocks, kNumMatchedFilters, + kAlignmentShiftSubBlocks); + + // Analyze the correlation between render and capture. + for (size_t k = 0; k < (100 + delay_samples / kSubBlockSize); ++k) { + RandomizeSampleVector(&random_generator, render); + signal_delay_buffer.Delay(render, capture); + filter.Update(render, capture); + } + + // Obtain the lag estimates. + auto lag_estimates = filter.GetLagEstimates(); + + // Find which lag estimate should be the most accurate. + rtc::Optional expected_most_accurate_lag_estimate; + size_t alignment_shift_sub_blocks = 0; + for (size_t k = 0; k < kNumMatchedFilters; ++k) { + if ((alignment_shift_sub_blocks + kWindowSizeSubBlocks / 2) * + kSubBlockSize > + delay_samples) { + expected_most_accurate_lag_estimate = rtc::Optional(k); + break; + } + alignment_shift_sub_blocks += kAlignmentShiftSubBlocks; + } + ASSERT_TRUE(expected_most_accurate_lag_estimate); + + // Verify that the expected most accurate lag estimate is the most accurate + // estimate. + for (size_t k = 0; k < kNumMatchedFilters; ++k) { + if (k != *expected_most_accurate_lag_estimate) { + EXPECT_GT(lag_estimates[*expected_most_accurate_lag_estimate].accuracy, + lag_estimates[k].accuracy); + } + } + + // Verify that all lag estimates are updated as expected for signals + // containing strong noise. + for (auto& le : lag_estimates) { + EXPECT_TRUE(le.updated); + } + + // Verify that the expected most accurate lag estimate is reliable. + EXPECT_TRUE(lag_estimates[*expected_most_accurate_lag_estimate].reliable); + + // Verify that the expected most accurate lag estimate is correct. + EXPECT_EQ(delay_samples, + lag_estimates[*expected_most_accurate_lag_estimate].lag); + } +} + +// Verifies that the matched filter does not produce reliable and accurate +// estimates for uncorrelated render and capture signals. +TEST(MatchedFilter, LagNotReliableForUncorrelatedRenderAndCapture) { + Random random_generator(42U); + std::array render; + std::array capture; + render.fill(0.f); + capture.fill(0.f); + ApmDataDumper data_dumper(0); + MatchedFilter filter(&data_dumper, kWindowSizeSubBlocks, kNumMatchedFilters, + kAlignmentShiftSubBlocks); + + // Analyze the correlation between render and capture. + for (size_t k = 0; k < 100; ++k) { + RandomizeSampleVector(&random_generator, render); + RandomizeSampleVector(&random_generator, capture); + filter.Update(render, capture); + } + + // Obtain the lag estimates. + auto lag_estimates = filter.GetLagEstimates(); + EXPECT_EQ(kNumMatchedFilters, lag_estimates.size()); + + // Verify that no lag estimates are reliable. + for (auto& le : lag_estimates) { + EXPECT_FALSE(le.reliable); + } +} + +// Verifies that the matched filter does not produce updated lag estimates for +// render signals of low level. +TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) { + Random random_generator(42U); + std::array render; + std::array capture; + render.fill(0.f); + capture.fill(0.f); + ApmDataDumper data_dumper(0); + MatchedFilter filter(&data_dumper, kWindowSizeSubBlocks, kNumMatchedFilters, + kAlignmentShiftSubBlocks); + + // Analyze the correlation between render and capture. + for (size_t k = 0; k < 100; ++k) { + RandomizeSampleVector(&random_generator, render); + for (auto& render_k : render) { + render_k *= 149.f / 32767.f; + } + std::copy(render.begin(), render.end(), capture.begin()); + filter.Update(render, capture); + } + + // Obtain the lag estimates. + auto lag_estimates = filter.GetLagEstimates(); + EXPECT_EQ(kNumMatchedFilters, lag_estimates.size()); + + // Verify that no lag estimates are updated and that no lag estimates are + // reliable. + for (auto& le : lag_estimates) { + EXPECT_FALSE(le.updated); + EXPECT_FALSE(le.reliable); + } +} + +// Verifies that the correct number of lag estimates are produced for a certain +// number of alignment shifts. +TEST(MatchedFilter, NumberOfLagEstimates) { + ApmDataDumper data_dumper(0); + for (size_t num_matched_filters = 0; num_matched_filters < 10; + ++num_matched_filters) { + MatchedFilter filter(&data_dumper, 32, num_matched_filters, 1); + EXPECT_EQ(num_matched_filters, filter.GetLagEstimates().size()); + } +} + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies the check for non-zero windows size. +TEST(MatchedFilter, ZeroWindowSize) { + ApmDataDumper data_dumper(0); + EXPECT_DEATH(MatchedFilter(&data_dumper, 0, 1, 1), ""); +} + +// Verifies the check for non-null data dumper. +TEST(MatchedFilter, NullDataDumper) { + EXPECT_DEATH(MatchedFilter(nullptr, 1, 1, 1), ""); +} + +#endif + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/render_delay_controller.cc b/webrtc/modules/audio_processing/aec3/render_delay_controller.cc index e18701e5ea..981b744eef 100644 --- a/webrtc/modules/audio_processing/aec3/render_delay_controller.cc +++ b/webrtc/modules/audio_processing/aec3/render_delay_controller.cc @@ -115,9 +115,8 @@ RenderDelayControllerImpl::RenderDelayControllerImpl( max_delay_(render_delay_buffer.MaxDelay()), delay_(render_delay_buffer.Delay()), render_buffer_(render_delay_buffer.MaxApiJitter() + 1), - delay_estimator_(data_dumper_.get(), sample_rate_hz) { - RTC_DCHECK(sample_rate_hz == 8000 || sample_rate_hz == 16000 || - sample_rate_hz == 32000 || sample_rate_hz == 48000); + delay_estimator_(data_dumper_.get()) { + RTC_DCHECK(ValidFullBandRate(sample_rate_hz)); } RenderDelayControllerImpl::~RenderDelayControllerImpl() = default; diff --git a/webrtc/modules/audio_processing/aec3/render_delay_controller_unittest.cc b/webrtc/modules/audio_processing/aec3/render_delay_controller_unittest.cc index 145f3a3042..9d382adb0f 100644 --- a/webrtc/modules/audio_processing/aec3/render_delay_controller_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/render_delay_controller_unittest.cc @@ -105,8 +105,7 @@ TEST(RenderDelayController, BasicApiCalls) { // Verifies that the RenderDelayController is able to align the signals for // simple timeshifts between the signals. -// TODO(peah): Activate the unittest once the required code has been landed. -TEST(RenderDelayController, DISABLED_Alignment) { +TEST(RenderDelayController, Alignment) { Random random_generator(42U); std::vector render_block(kBlockSize, 0.f); std::vector capture_block(kBlockSize, 0.f); @@ -148,8 +147,7 @@ TEST(RenderDelayController, DISABLED_Alignment) { // Verifies that the RenderDelayController is able to align the signals for // simple timeshifts between the signals when there is jitter in the API calls. -// TODO(peah): Activate the unittest once the required code has been landed. -TEST(RenderDelayController, DISABLED_AlignmentWithJitter) { +TEST(RenderDelayController, AlignmentWithJitter) { Random random_generator(42U); std::vector render_block(kBlockSize, 0.f); std::vector capture_block(kBlockSize, 0.f);