Adding full initial version of delay estimation functionality in echo

canceller 3

This CL adds the code for all of the delay estimation functionality that is
available in the first version of echo canceller 3. The code completes
the class EchoPathDelayEstimator.

Note that this code does not yet include any handling of clock-drift so
there will be upcoming versions of this code.

Also note that the CL includes some minor changes in other files for
echo canceller 3.

BUG=webrtc:6018

Review-Url: https://codereview.webrtc.org/2644123002
Cr-Commit-Position: refs/heads/master@{#16489}
This commit is contained in:
peah 2017-02-08 05:08:56 -08:00 committed by Commit bot
parent d4ed7f59e4
commit 219208991b
21 changed files with 1171 additions and 54 deletions

View File

@ -33,6 +33,8 @@ rtc_static_library("audio_processing") {
"aec3/block_processor.h",
"aec3/cascaded_biquad_filter.cc",
"aec3/cascaded_biquad_filter.h",
"aec3/decimator_by_4.cc",
"aec3/decimator_by_4.h",
"aec3/echo_canceller3.cc",
"aec3/echo_canceller3.h",
"aec3/echo_path_delay_estimator.cc",
@ -42,6 +44,10 @@ rtc_static_library("audio_processing") {
"aec3/echo_remover.h",
"aec3/frame_blocker.cc",
"aec3/frame_blocker.h",
"aec3/matched_filter.cc",
"aec3/matched_filter.h",
"aec3/matched_filter_lag_aggregator.cc",
"aec3/matched_filter_lag_aggregator.h",
"aec3/render_delay_buffer.cc",
"aec3/render_delay_buffer.h",
"aec3/render_delay_controller.cc",
@ -519,10 +525,13 @@ if (rtc_include_tests) {
"aec3/block_framer_unittest.cc",
"aec3/block_processor_unittest.cc",
"aec3/cascaded_biquad_filter_unittest.cc",
"aec3/decimator_by_4_unittest.cc",
"aec3/echo_canceller3_unittest.cc",
"aec3/echo_path_delay_estimator_unittest.cc",
"aec3/echo_remover_unittest.cc",
"aec3/frame_blocker_unittest.cc",
"aec3/matched_filter_lag_aggregator_unittest.cc",
"aec3/matched_filter_unittest.cc",
"aec3/mock/mock_block_processor.h",
"aec3/mock/mock_echo_remover.h",
"aec3/mock/mock_render_delay_buffer.h",

View File

@ -34,6 +34,11 @@ constexpr int LowestBandRate(int sample_rate_hz) {
return sample_rate_hz == 8000 ? sample_rate_hz : 16000;
}
// Returns true only for the full-band sample rates accepted here: 8, 16, 32
// and 48 kHz. Kept as a single return expression so that it remains usable in
// constant expressions (e.g. static_assert).
constexpr bool ValidFullBandRate(int sample_rate_hz) {
  return !(sample_rate_hz != 8000 && sample_rate_hz != 16000 &&
           sample_rate_hz != 32000 && sample_rate_hz != 48000);
}
static_assert(1 == NumBandsForRate(8000), "Number of bands for 8 kHz");
static_assert(1 == NumBandsForRate(16000), "Number of bands for 16 kHz");
static_assert(2 == NumBandsForRate(32000), "Number of bands for 32 kHz");
@ -47,6 +52,17 @@ static_assert(16000 == LowestBandRate(32000),
static_assert(16000 == LowestBandRate(48000),
"Sample rate of band 0 for 48 kHz");
static_assert(ValidFullBandRate(8000),
"Test that 8 kHz is a valid sample rate");
static_assert(ValidFullBandRate(16000),
"Test that 16 kHz is a valid sample rate");
static_assert(ValidFullBandRate(32000),
"Test that 32 kHz is a valid sample rate");
static_assert(ValidFullBandRate(48000),
"Test that 48 kHz is a valid sample rate");
static_assert(!ValidFullBandRate(8001),
"Test that 8001 Hz is not a valid sample rate");
} // namespace webrtc
#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_AEC3_CONSTANTS_H_

View File

@ -61,7 +61,9 @@ BlockProcessorImpl::BlockProcessorImpl(
sample_rate_hz_(sample_rate_hz),
render_buffer_(std::move(render_buffer)),
delay_controller_(std::move(delay_controller)),
echo_remover_(std::move(echo_remover)) {}
echo_remover_(std::move(echo_remover)) {
RTC_DCHECK(ValidFullBandRate(sample_rate_hz_));
}
BlockProcessorImpl::~BlockProcessorImpl() = default;

View File

@ -252,6 +252,12 @@ TEST(BlockProcessor, NullBufferRenderParameter) {
"");
}
// Verifies the check for correct sample rate.
TEST(BlockProcessor, WrongSampleRate) {
  // 8001 Hz is not one of the supported full-band rates, so creating the
  // processor is expected to hit the sample rate check and abort.
  EXPECT_DEATH(
      std::unique_ptr<BlockProcessor>(BlockProcessor::Create(8001)), "");
}
#endif
} // namespace webrtc

View File

@ -0,0 +1,44 @@
/*
* Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "webrtc/modules/audio_processing/aec3/decimator_by_4.h"
#include "webrtc/base/checks.h"
namespace webrtc {
namespace {
// Coefficients of one second-order Butterworth low-pass section, used as the
// anti-aliasing filter before decimation. In MATLAB/Octave notation:
// [B,A] = butter(2,1500/16000) which are the same as [B,A] =
// butter(2,750/8000).
// The first row appears to hold B and the second row A without its leading 1
// (NOTE(review): confirm against CascadedBiQuadFilter::BiQuadCoefficients).
const CascadedBiQuadFilter::BiQuadCoefficients kLowPassFilterCoefficients = {
{0.0179f, 0.0357f, 0.0179f},
{-1.5879f, 0.6594f}};
} // namespace
// Builds the anti-aliasing filter from the shared low-pass coefficients. The
// second argument (3) is presumably the number of cascaded biquad sections --
// NOTE(review): confirm against CascadedBiQuadFilter.
DecimatorBy4::DecimatorBy4()
    : low_pass_filter_(kLowPassFilterCoefficients, 3) {
}
// Low-pass filters one block of kBlockSize samples and writes every fourth
// filtered sample to |out|.
//   in:  exactly kBlockSize input samples.
//   out: non-null destination for the kSubBlockSize decimated samples.
void DecimatorBy4::Decimate(rtc::ArrayView<const float> in,
                            std::array<float, kSubBlockSize>* out) {
  RTC_DCHECK_EQ(kBlockSize, in.size());
  RTC_DCHECK(out);

  // Band-limit the block first so that the sample dropping below does not
  // fold high-frequency content into the retained band.
  std::array<float, kBlockSize> filtered;
  low_pass_filter_.Process(in, filtered);

  // Keep samples 0, 4, 8, ... of the filtered block.
  size_t read_index = 0;
  for (float& decimated_sample : *out) {
    RTC_DCHECK_GT(kBlockSize, read_index);
    decimated_sample = filtered[read_index];
    read_index += 4;
  }
}
} // namespace webrtc

View File

@ -0,0 +1,39 @@
/*
* Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_DECIMATOR_BY_4_H_
#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_DECIMATOR_BY_4_H_
#include <array>
#include "webrtc/base/array_view.h"
#include "webrtc/base/constructormagic.h"
#include "webrtc/modules/audio_processing/aec3/aec3_constants.h"
#include "webrtc/modules/audio_processing/aec3/cascaded_biquad_filter.h"
namespace webrtc {
// Provides functionality for decimating a signal by 4.
class DecimatorBy4 {
 public:
  DecimatorBy4();

  // Low-pass filters |in| (one block of kBlockSize samples) and stores every
  // fourth filtered sample in |out|, which must not be null. The filter state
  // is kept between calls, so consecutive blocks of a signal should be passed
  // to the same instance.
  void Decimate(rtc::ArrayView<const float> in,
                std::array<float, kSubBlockSize>* out);

 private:
  // Anti-aliasing filter applied before the samples are dropped.
  CascadedBiQuadFilter low_pass_filter_;

  RTC_DISALLOW_COPY_AND_ASSIGN(DecimatorBy4);
};
} // namespace webrtc
#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_DECIMATOR_BY_4_H_

View File

@ -0,0 +1,127 @@
/*
* Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "webrtc/modules/audio_processing/aec3/decimator_by_4.h"
#include <math.h>
#include <algorithm>
#include <array>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>
#include "webrtc/modules/audio_processing/aec3/aec3_constants.h"
#include "webrtc/test/gtest.h"
namespace webrtc {
namespace {
// Formats the sample rate for use in test failure messages.
std::string ProduceDebugText(int sample_rate_hz) {
  return "Sample rate: " + std::to_string(sample_rate_hz);
}
constexpr float kPi = 3.141592f;
constexpr size_t kNumStartupBlocks = 50;
constexpr size_t kNumBlocks = 1000;
// Runs kNumBlocks blocks of a full-scale sinusoid through a DecimatorBy4 and
// reports the average power of the input and of the decimated output.
//   sample_rate_hz:          rate used to generate the sinusoid.
//   sinusoidal_frequency_hz: frequency of the generated sinusoid.
//   input_power:             receives the mean squared input sample value.
//   output_power:            receives the mean squared output sample value.
// Both powers are measured after skipping the first kNumStartupBlocks blocks
// so that the filter start-up transient does not affect the result.
void ProduceDecimatedSinusoidalOutputPower(int sample_rate_hz,
                                           float sinusoidal_frequency_hz,
                                           float* input_power,
                                           float* output_power) {
  static_assert(kNumBlocks > kNumStartupBlocks,
                "There must be blocks left to evaluate after the startup");

  // The signal buffers span all kNumBlocks blocks; keep them on the heap
  // rather than in this function's stack frame.
  std::vector<float> input(kBlockSize * kNumBlocks);

  // Produce a sinusoid of the specified frequency.
  for (size_t k = 0; k < input.size(); ++k) {
    input[k] =
        32767.f * sin(2.f * kPi * sinusoidal_frequency_hz * k / sample_rate_hz);
  }

  // Decimate block by block with a single decimator instance so that the
  // filter state carries over between blocks.
  DecimatorBy4 decimator;
  std::vector<float> output(kSubBlockSize * kNumBlocks);
  for (size_t k = 0; k < kNumBlocks; ++k) {
    std::array<float, kSubBlockSize> sub_block;
    decimator.Decimate(
        rtc::ArrayView<const float>(&input[k * kBlockSize], kBlockSize),
        &sub_block);
    std::copy(sub_block.begin(), sub_block.end(),
              output.begin() + k * kSubBlockSize);
  }

  // Evaluate only the portion after the startup blocks.
  rtc::ArrayView<const float> input_to_evaluate(
      &input[kNumStartupBlocks * kBlockSize],
      (kNumBlocks - kNumStartupBlocks) * kBlockSize);
  rtc::ArrayView<const float> output_to_evaluate(
      &output[kNumStartupBlocks * kSubBlockSize],
      (kNumBlocks - kNumStartupBlocks) * kSubBlockSize);

  *input_power =
      std::inner_product(input_to_evaluate.begin(), input_to_evaluate.end(),
                         input_to_evaluate.begin(), 0.f) /
      input_to_evaluate.size();
  *output_power =
      std::inner_product(output_to_evaluate.begin(), output_to_evaluate.end(),
                         output_to_evaluate.begin(), 0.f) /
      output_to_evaluate.size();
}
} // namespace
// Verifies that there is little aliasing from upper frequencies in the
// downsampling.
TEST(DecimatorBy4, NoLeakageFromUpperFrequencies) {
  float input_power;
  float output_power;
  for (auto rate : {8000, 16000, 32000, 48000}) {
    // Attach the rate to any failure reported in this iteration; without
    // SCOPED_TRACE the produced text would simply be discarded.
    SCOPED_TRACE(ProduceDebugText(rate));
    // A tone at 3/8 of the sample rate must be attenuated by the decimator
    // to less than 0.01 % of its input power.
    ProduceDecimatedSinusoidalOutputPower(rate, 3.f / 8.f * rate, &input_power,
                                          &output_power);
    EXPECT_GT(0.0001f * input_power, output_power);
  }
}
// Verifies that the impact of low-frequency content is small during the
// downsampling.
TEST(DecimatorBy4, NoImpactOnLowerFrequencies) {
  float input_power;
  float output_power;
  for (auto rate : {8000, 16000, 32000, 48000}) {
    // Attach the rate to any failure reported in this iteration; without
    // SCOPED_TRACE the produced text would simply be discarded.
    SCOPED_TRACE(ProduceDebugText(rate));
    // A 200 Hz tone must retain more than 70 % of its power after
    // decimation.
    ProduceDecimatedSinusoidalOutputPower(rate, 200.f, &input_power,
                                          &output_power);
    EXPECT_LT(0.7f * input_power, output_power);
  }
}
#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
// Verifies the check for the input size.
TEST(DecimatorBy4, WrongInputSize) {
  DecimatorBy4 decimator;
  // One sample short of a full block, which must trip the input size check.
  std::vector<float> x(kBlockSize - 1, 0.f);
  std::array<float, kSubBlockSize> x_downsampled;
  EXPECT_DEATH(decimator.Decimate(x, &x_downsampled), "");
}
// Verifies the check for non-null output parameter.
TEST(DecimatorBy4, NullOutput) {
  DecimatorBy4 decimator;
  // Correctly sized input; only the null output pointer is invalid.
  std::vector<float> x(kBlockSize, 0.f);
  EXPECT_DEATH(decimator.Decimate(x, nullptr), "");
}
#endif
} // namespace webrtc

View File

@ -227,6 +227,8 @@ EchoCanceller3::EchoCanceller3(int sample_rate_hz,
std::vector<float>(frame_length_, 0.f)),
block_(num_bands_, std::vector<float>(kBlockSize, 0.f)),
sub_frame_view_(num_bands_) {
RTC_DCHECK(ValidFullBandRate(sample_rate_hz_));
std::unique_ptr<CascadedBiQuadFilter> render_highpass_filter;
if (use_highpass_filter) {
render_highpass_filter.reset(new CascadedBiQuadFilter(

View File

@ -719,6 +719,12 @@ TEST(EchoCanceller3InputCheck, NullCaptureProcessingParameter) {
EXPECT_DEATH(EchoCanceller3(8000, false).ProcessCapture(nullptr, false), "");
}
// Verifies the check for correct sample rate.
TEST(EchoCanceller3InputCheck, WrongSampleRate) {
  // 8001 Hz is not a supported full-band rate, so construction is expected
  // to abort. No data dumper is needed for this check.
  EXPECT_DEATH(EchoCanceller3(8001, false), "");
}
#endif
} // namespace webrtc

View File

@ -9,29 +9,71 @@
*/
#include "webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.h"
#include <algorithm>
#include <array>
#include "webrtc/base/checks.h"
#include "webrtc/modules/audio_processing/aec3/aec3_constants.h"
#include "webrtc/modules/audio_processing/include/audio_processing.h"
#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h"
namespace webrtc {
namespace {

// Configuration of the matched filter bank. The window size and the alignment
// shift between consecutive filters are expressed in sub-blocks; the shift is
// three quarters of the window size.
constexpr size_t kNumMatchedFilters = 4;
constexpr size_t kMatchedFilterWindowSizeSubBlocks = 32;
constexpr size_t kMatchedFilterAlignmentShiftSizeSubBlocks =
    kMatchedFilterWindowSizeSubBlocks * 3 / 4;

// Factor by which the signals are downsampled before being correlated. Lags
// found on the downsampled signals are multiplied by this factor to obtain a
// delay in samples at the original rate.
constexpr int kDownSamplingFactor = 4;

}  // namespace

// Creates the delay estimator. |data_dumper| must be non-null; it is stored
// and also handed to the matched filter and to the lag aggregator.
EchoPathDelayEstimator::EchoPathDelayEstimator(ApmDataDumper* data_dumper)
    : data_dumper_(data_dumper),
      matched_filter_(data_dumper_,
                      kMatchedFilterWindowSizeSubBlocks,
                      kNumMatchedFilters,
                      kMatchedFilterAlignmentShiftSizeSubBlocks),
      matched_filter_lag_aggregator_(data_dumper_,
                                     matched_filter_.NumLagEstimates()) {
  RTC_DCHECK(data_dumper);
}
EchoPathDelayEstimator::~EchoPathDelayEstimator() = default;
// Processes one block of render and capture samples and returns the estimated
// echo path delay.
//   render, capture: blocks of exactly kBlockSize samples each.
// Returns the delay in samples at the input rate (the aggregated matched
// filter lag multiplied by the downsampling factor), or an empty Optional
// when the lag aggregator has no estimate to report.
rtc::Optional<size_t> EchoPathDelayEstimator::EstimateDelay(
    rtc::ArrayView<const float> render,
    rtc::ArrayView<const float> capture) {
  RTC_DCHECK_EQ(kBlockSize, capture.size());
  RTC_DCHECK_EQ(render.size(), capture.size());

  // Downsample both signals before correlating them.
  std::array<float, kSubBlockSize> downsampled_render;
  std::array<float, kSubBlockSize> downsampled_capture;
  render_decimator_.Decimate(render, &downsampled_render);
  capture_decimator_.Decimate(capture, &downsampled_capture);

  matched_filter_.Update(downsampled_render, downsampled_capture);

  const rtc::Optional<size_t> aggregated_matched_filter_lag =
      matched_filter_lag_aggregator_.Aggregate(
          matched_filter_.GetLagEstimates());

  // The detected delay in samples is the aggregated matched filter lag
  // compensated by the down sampling factor for the signal being correlated.
  const rtc::Optional<size_t> delay_samples =
      aggregated_matched_filter_lag
          ? rtc::Optional<size_t>(*aggregated_matched_filter_lag *
                                  kDownSamplingFactor)
          : rtc::Optional<size_t>();

  // TODO(peah): Move this logging outside of this class once EchoCanceller3
  // development is done.
  data_dumper_->DumpRaw("aec3_echo_path_delay_estimator_delay",
                        delay_samples ? static_cast<int>(*delay_samples) : -1);

  return delay_samples;
}
} // namespace webrtc

View File

@ -15,19 +15,31 @@
#include "webrtc/base/constructormagic.h"
#include "webrtc/base/optional.h"
#include "webrtc/modules/audio_processing/aec3/matched_filter.h"
#include "webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.h"
#include "webrtc/modules/audio_processing/aec3/decimator_by_4.h"
namespace webrtc {
class ApmDataDumper;
// Estimates the delay of the echo path.
class EchoPathDelayEstimator {
 public:
  // |data_dumper| is stored as a raw pointer and is not owned by this class.
  explicit EchoPathDelayEstimator(ApmDataDumper* data_dumper);
  ~EchoPathDelayEstimator();

  // Produce a delay estimate if such is available.
  rtc::Optional<size_t> EstimateDelay(rtc::ArrayView<const float> render,
                                      rtc::ArrayView<const float> capture);

 private:
  ApmDataDumper* const data_dumper_;
  // Separate decimators for the render and capture signals.
  DecimatorBy4 render_decimator_;
  DecimatorBy4 capture_decimator_;
  // Correlates the decimated signals and produces per-filter lag estimates.
  MatchedFilter matched_filter_;
  // Combines the per-filter lag estimates into a single lag.
  MatchedFilterLagAggregator matched_filter_lag_aggregator_;

  RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(EchoPathDelayEstimator);
};
} // namespace webrtc

View File

@ -10,19 +10,22 @@
#include "webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.h"
#include <algorithm>
#include <sstream>
#include <string>
#include "webrtc/base/random.h"
#include "webrtc/modules/audio_processing/aec3/aec3_constants.h"
#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h"
#include "webrtc/modules/audio_processing/test/echo_canceller_test_tools.h"
#include "webrtc/test/gtest.h"
namespace webrtc {
namespace {
// Formats the tested delay for use in test failure messages.
std::string ProduceDebugText(size_t delay) {
  std::ostringstream ss;
  ss << "Delay: " << delay;
  return ss.str();
}
@ -30,57 +33,120 @@ std::string ProduceDebugText(int sample_rate_hz) {
// Verifies that the basic API calls work.
TEST(EchoPathDelayEstimator, BasicApiCalls) {
  // The estimator no longer takes a sample rate, so a single instance is
  // exercised with 100 all-zero blocks.
  ApmDataDumper data_dumper(0);
  EchoPathDelayEstimator estimator(&data_dumper);
  std::vector<float> render(kBlockSize, 0.f);
  std::vector<float> capture(kBlockSize, 0.f);
  for (size_t k = 0; k < 100; ++k) {
    estimator.EstimateDelay(render, capture);
  }
}
// Verifies that the delay estimator produces correct delay for artificially
// delayed signals.
TEST(EchoPathDelayEstimator, DelayEstimation) {
  Random random_generator(42U);
  std::vector<float> render(kBlockSize, 0.f);
  std::vector<float> capture(kBlockSize, 0.f);
  ApmDataDumper data_dumper(0);

  for (size_t delay_samples : {0, 64, 150, 200, 800, 4000}) {
    SCOPED_TRACE(ProduceDebugText(delay_samples));
    DelayBuffer<float> signal_delay_buffer(delay_samples);
    EchoPathDelayEstimator estimator(&data_dumper);

    // Feed random render blocks and a delayed copy as capture. Only the
    // estimate from the last block is checked.
    const size_t num_blocks = 100 + delay_samples / kBlockSize;
    rtc::Optional<size_t> estimated_delay_samples;
    for (size_t block = 0; block < num_blocks; ++block) {
      RandomizeSampleVector(&random_generator, render);
      signal_delay_buffer.Delay(render, capture);
      estimated_delay_samples = estimator.EstimateDelay(render, capture);
    }

    if (!estimated_delay_samples) {
      ADD_FAILURE();
      continue;
    }

    // Due to the internal down-sampling by 4 done inside the delay estimator
    // the estimated delay cannot be expected to be closer than 4 samples to
    // the true delay.
    EXPECT_NEAR(delay_samples, *estimated_delay_samples, 4);
  }
}
// Verifies that the delay estimator does not produce delay estimates too
// quickly.
TEST(EchoPathDelayEstimator, NoInitialDelayestimates) {
  Random random_generator(42U);
  std::vector<float> render(kBlockSize, 0.f);
  std::vector<float> capture(kBlockSize, 0.f);
  ApmDataDumper data_dumper(0);
  EchoPathDelayEstimator estimator(&data_dumper);

  // Even with capture identical to render, none of the first 19 blocks may
  // yield an estimate.
  for (size_t block = 0; block < 19; ++block) {
    RandomizeSampleVector(&random_generator, render);
    capture = render;
    EXPECT_FALSE(estimator.EstimateDelay(render, capture));
  }
}
// Verifies that the delay estimator does not produce delay estimates for render
// signals of low level.
TEST(EchoPathDelayEstimator, NoDelayEstimatesForLowLevelRenderSignals) {
  Random random_generator(42U);
  std::vector<float> render(kBlockSize, 0.f);
  std::vector<float> capture(kBlockSize, 0.f);
  ApmDataDumper data_dumper(0);
  EchoPathDelayEstimator estimator(&data_dumper);

  for (size_t block = 0; block < 100; ++block) {
    RandomizeSampleVector(&random_generator, render);
    // Scale the render signal down to a low level before it is used.
    for (float& sample : render) {
      sample *= 100.f / 32767.f;
    }
    capture = render;
    EXPECT_FALSE(estimator.EstimateDelay(render, capture));
  }
}
// Verifies that the delay estimator does not produce delay estimates for
// uncorrelated signals.
TEST(EchoPathDelayEstimator, NoDelayEstimatesForUncorrelatedSignals) {
  Random random_generator(42U);
  std::vector<float> render(kBlockSize, 0.f);
  std::vector<float> capture(kBlockSize, 0.f);
  ApmDataDumper data_dumper(0);
  EchoPathDelayEstimator estimator(&data_dumper);

  // Render and capture are drawn independently, so no block may produce a
  // delay estimate.
  for (size_t block = 0; block < 100; ++block) {
    RandomizeSampleVector(&random_generator, render);
    RandomizeSampleVector(&random_generator, capture);
    EXPECT_FALSE(estimator.EstimateDelay(render, capture));
  }
}
#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
// Verifies the check for correct sample rate.
TEST(EchoPathDelayEstimator, WrongSampleRate) {
ApmDataDumper data_dumper(0);
EXPECT_DEATH(EchoPathDelayEstimator remover(&data_dumper, 8001), "");
}
// Verifies the check for the render blocksize.
// TODO(peah): Re-enable the test once the issue with memory leaks during DEATH
// tests on test bots has been fixed.
TEST(EchoPathDelayEstimator, DISABLED_WrongRenderBlockSize) {
  ApmDataDumper data_dumper(0);
  EchoPathDelayEstimator estimator(&data_dumper);
  // The render block is one sample short; the capture block is valid.
  std::vector<float> render(kBlockSize - 1, 0.f);
  std::vector<float> capture(kBlockSize, 0.f);
  EXPECT_DEATH(estimator.EstimateDelay(render, capture), "");
}
// Verifies the check for the capture blocksize.
// TODO(peah): Re-enable the test once the issue with memory leaks during DEATH
// tests on test bots has been fixed.
TEST(EchoPathDelayEstimator, WrongCaptureBlockSize) {
  ApmDataDumper data_dumper(0);
  EchoPathDelayEstimator estimator(&data_dumper);
  // The capture block is one sample short; the render block is valid.
  std::vector<float> render(kBlockSize, 0.f);
  std::vector<float> capture(kBlockSize - 1, 0.f);
  EXPECT_DEATH(estimator.EstimateDelay(render, capture), "");
}
// Verifies the check for non-null data dumper.
TEST(EchoPathDelayEstimator, NullDataDumper) {
  // Construction with a null data dumper must hit the non-null check.
  EXPECT_DEATH(EchoPathDelayEstimator(nullptr), "");
}
#endif

View File

@ -42,8 +42,7 @@ class EchoRemoverImpl final : public EchoRemover {
// TODO(peah): Add functionality.
EchoRemoverImpl::EchoRemoverImpl(int sample_rate_hz)
    : sample_rate_hz_(sample_rate_hz) {
  // ValidFullBandRate() holds the shared definition of the supported rates,
  // so the list is not repeated here.
  RTC_DCHECK(ValidFullBandRate(sample_rate_hz_));
}
EchoRemoverImpl::~EchoRemoverImpl() = default;

View File

@ -0,0 +1,172 @@
/*
* Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "webrtc/modules/audio_processing/aec3/matched_filter.h"
#include <algorithm>
#include <numeric>
#include "webrtc/modules/audio_processing/include/audio_processing.h"
#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h"
namespace webrtc {
// Creates a zero-filled buffer of |size| samples; |size| must be a whole
// number of sub-blocks.
MatchedFilter::IndexedBuffer::IndexedBuffer(size_t size)
    : data(size, 0.f) {
  RTC_DCHECK_EQ(0, data.size() % kSubBlockSize);
}

MatchedFilter::IndexedBuffer::~IndexedBuffer() = default;
// Creates a bank of matched filters.
//   data_dumper:                non-null dumper used for debug output.
//   window_size_sub_blocks:     length of each filter in sub-blocks (> 0).
//   num_matched_filters:        number of filters (> 0).
//   alignment_shift_sub_blocks: extra render signal shift, in sub-blocks,
//                               between consecutive filters.
MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
                             size_t window_size_sub_blocks,
                             int num_matched_filters,
                             size_t alignment_shift_sub_blocks)
    : data_dumper_(data_dumper),
      filter_intra_lag_shift_(alignment_shift_sub_blocks * kSubBlockSize),
      filters_(num_matched_filters,
               std::vector<float>(window_size_sub_blocks * kSubBlockSize, 0.f)),
      lag_estimates_(num_matched_filters),
      // Room for the shifts of all filters, one filter window and one
      // additional sub-block.
      x_buffer_(kSubBlockSize *
                (alignment_shift_sub_blocks * num_matched_filters +
                 window_size_sub_blocks +
                 1)) {
  RTC_DCHECK(data_dumper);
  // Update() unconditionally reads filters_[0], so at least one filter is
  // required.
  RTC_DCHECK_LT(0, num_matched_filters);
  RTC_DCHECK_EQ(0, x_buffer_.data.size() % kSubBlockSize);
  RTC_DCHECK_LT(0, window_size_sub_blocks);
}

MatchedFilter::~MatchedFilter() = default;
// Feeds one sub-block of render (x) and capture (y) samples to every matched
// filter. Each filter predicts the capture samples from its own shifted
// segment of the buffered render signal, is adapted in an NLMS manner when
// the render energy is high enough, and yields a LagEstimate based on the
// position of its largest-magnitude coefficient.
void MatchedFilter::Update(const std::array<float, kSubBlockSize>& render,
                           const std::array<float, kSubBlockSize>& capture) {
  const std::array<float, kSubBlockSize>& x = render;
  const std::array<float, kSubBlockSize>& y = capture;

  // The data dumping at the end of the filter loop assumes exactly four
  // filters. Checked once up front, before filters_[0] is read.
  // TODO(peah): Remove once development of EchoCanceller3 is fully done.
  RTC_DCHECK_EQ(4, filters_.size());

  // Minimum render energy over a filter window for the filter to be adapted.
  const float x2_sum_threshold = filters_[0].size() * 150.f * 150.f;

  // Anchor for the matched filter error: the energy of the capture
  // sub-block. It does not depend on the filter, so it is computed once.
  const float error_sum_anchor =
      std::inner_product(y.begin(), y.end(), y.begin(), 0.f);
  const float kMatchingFilterThreshold = 0.3f;

  // Insert the new subblock into x_buffer. The write position moves one
  // sub-block backwards (with wrap-around) and the samples are stored in
  // reversed order.
  x_buffer_.index = (x_buffer_.index - kSubBlockSize + x_buffer_.data.size()) %
                    x_buffer_.data.size();
  RTC_DCHECK_LE(kSubBlockSize, x_buffer_.data.size() - x_buffer_.index);
  std::copy(x.rbegin(), x.rend(), x_buffer_.data.begin() + x_buffer_.index);

  // Apply all matched filters.
  size_t alignment_shift = 0;
  for (size_t n = 0; n < filters_.size(); ++n) {
    float error_sum = 0.f;
    bool filters_updated = false;
    size_t x_start_index =
        (x_buffer_.index + alignment_shift + kSubBlockSize - 1) %
        x_buffer_.data.size();

    // Process for all samples in the sub-block.
    for (size_t i = 0; i < kSubBlockSize; ++i) {
      // As x_buffer is a circular buffer, all of the processing is split into
      // two loops around the wrapping of the buffer.
      const size_t loop_size_1 =
          std::min(filters_[n].size(), x_buffer_.data.size() - x_start_index);
      const size_t loop_size_2 = filters_[n].size() - loop_size_1;
      RTC_DCHECK_EQ(filters_[n].size(), loop_size_1 + loop_size_2);

      // x * x.
      float x2_sum = std::inner_product(
          x_buffer_.data.begin() + x_start_index,
          x_buffer_.data.begin() + x_start_index + loop_size_1,
          x_buffer_.data.begin() + x_start_index, 0.f);

      // Apply the matched filter as filter * x.
      float s = std::inner_product(filters_[n].begin(),
                                   filters_[n].begin() + loop_size_1,
                                   x_buffer_.data.begin() + x_start_index, 0.f);

      if (loop_size_2 > 0) {
        // Update the cumulative sum of x * x.
        x2_sum = std::inner_product(x_buffer_.data.begin(),
                                    x_buffer_.data.begin() + loop_size_2,
                                    x_buffer_.data.begin(), x2_sum);

        // Compute the matched filter output filter * x in a cumulative manner.
        s = std::inner_product(x_buffer_.data.begin(),
                               x_buffer_.data.begin() + loop_size_2,
                               filters_[n].begin() + loop_size_1, s);
      }

      // Compute the matched filter error, limited to the 16-bit sample range.
      const float e = std::min(32767.f, std::max(-32768.f, y[i] - s));
      error_sum += e * e;

      // Update the matched filter estimate in an NLMS manner.
      if (x2_sum > x2_sum_threshold) {
        filters_updated = true;
        RTC_DCHECK_LT(0.f, x2_sum);
        const float alpha = 0.7f * e / x2_sum;

        // filter = filter + 0.7 * (y - filter * x) * x / (x * x).
        std::transform(filters_[n].begin(), filters_[n].begin() + loop_size_1,
                       x_buffer_.data.begin() + x_start_index,
                       filters_[n].begin(),
                       [&](float a, float b) { return a + alpha * b; });

        if (loop_size_2 > 0) {
          // Same update for the part of the window that wrapped around. Here
          // the first lambda argument is the x sample and the second the
          // filter coefficient.
          std::transform(x_buffer_.data.begin(),
                         x_buffer_.data.begin() + loop_size_2,
                         filters_[n].begin() + loop_size_1,
                         filters_[n].begin() + loop_size_1,
                         [&](float a, float b) { return b + alpha * a; });
        }
      }

      // Step the window start one sample backwards, wrapping at zero.
      x_start_index =
          x_start_index > 0 ? x_start_index - 1 : x_buffer_.data.size() - 1;
    }

    // Estimate the lag in the matched filter as the distance to the portion in
    // the filter that contributes the most to the matched filter output. This
    // is detected as the peak of the matched filter.
    const size_t lag_estimate = std::distance(
        filters_[n].begin(),
        std::max_element(
            filters_[n].begin(), filters_[n].end(),
            [](float a, float b) -> bool { return a * a < b * b; }));

    // Update the lag estimates for the matched filter. The estimate counts as
    // reliable when the error energy is below a fraction of the capture
    // energy.
    lag_estimates_[n] =
        LagEstimate(error_sum_anchor - error_sum,
                    error_sum < kMatchingFilterThreshold * error_sum_anchor,
                    lag_estimate + alignment_shift, filters_updated);

    // TODO(peah): Remove once development of EchoCanceller3 is fully done.
    switch (n) {
      case 0:
        data_dumper_->DumpRaw("aec3_correlator_0_h", filters_[0]);
        break;
      case 1:
        data_dumper_->DumpRaw("aec3_correlator_1_h", filters_[1]);
        break;
      case 2:
        data_dumper_->DumpRaw("aec3_correlator_2_h", filters_[2]);
        break;
      case 3:
        data_dumper_->DumpRaw("aec3_correlator_3_h", filters_[3]);
        break;
      default:
        RTC_DCHECK(false);
    }

    alignment_shift += filter_intra_lag_shift_;
  }
}
} // namespace webrtc

View File

@ -0,0 +1,84 @@
/*
* Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_
#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_
#include <array>
#include <memory>
#include <vector>
#include "webrtc/base/constructormagic.h"
#include "webrtc/base/optional.h"
#include "webrtc/modules/audio_processing/aec3/aec3_constants.h"
namespace webrtc {
class ApmDataDumper;
// Produces recursively updated cross-correlation estimates for several signal
// shifts where the intra-shift spacing is uniform.
class MatchedFilter {
 public:
  // Stores properties for the lag estimate corresponding to a particular signal
  // shift.
  struct LagEstimate {
    LagEstimate() = default;
    LagEstimate(float accuracy, bool reliable, size_t lag, bool updated)
        : accuracy(accuracy), reliable(reliable), lag(lag), updated(updated) {}

    // Quality measure of the estimate; larger means a better match.
    float accuracy = 0.f;
    // Whether the estimate passed the reliability test.
    bool reliable = false;
    // Estimated lag, including the signal shift of the producing filter.
    size_t lag = 0;
    // Whether the producing filter was adapted during the latest update.
    bool updated = false;
  };

  // |data_dumper| must be non-null. The window size and the alignment shift
  // between consecutive filters are given in sub-blocks.
  MatchedFilter(ApmDataDumper* data_dumper,
                size_t window_size_sub_blocks,
                int num_matched_filters,
                size_t alignment_shift_sub_blocks);

  ~MatchedFilter();

  // Updates the correlation with the values in render and capture.
  void Update(const std::array<float, kSubBlockSize>& render,
              const std::array<float, kSubBlockSize>& capture);

  // Returns the current lag estimates.
  rtc::ArrayView<const MatchedFilter::LagEstimate> GetLagEstimates() const {
    return lag_estimates_;
  }

  // Returns the number of lag estimates produced using the shifted signals.
  size_t NumLagEstimates() const { return filters_.size(); }

 private:
  // Provides buffer with a related index.
  struct IndexedBuffer {
    explicit IndexedBuffer(size_t size);
    ~IndexedBuffer();

    std::vector<float> data;
    // Position in |data|. Stored as size_t because it is only ever combined
    // with container sizes and offsets, which are unsigned.
    size_t index = 0;
    RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(IndexedBuffer);
  };

  ApmDataDumper* const data_dumper_;
  // Shift, in samples, between the signal segments of consecutive filters.
  const size_t filter_intra_lag_shift_;
  // One coefficient vector per matched filter.
  std::vector<std::vector<float>> filters_;
  // One lag estimate per matched filter.
  std::vector<LagEstimate> lag_estimates_;
  // Circular buffer with the render signal history.
  IndexedBuffer x_buffer_;

  RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(MatchedFilter);
};
} // namespace webrtc
#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_

View File

@ -0,0 +1,66 @@
/*
* Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.h"
#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h"
namespace webrtc {
// Creates an aggregator for |num_lag_estimates| (> 0) lag estimates per call.
// |data_dumper| must be non-null. All per-estimate update counters start at
// zero.
MatchedFilterLagAggregator::MatchedFilterLagAggregator(
    ApmDataDumper* data_dumper,
    size_t num_lag_estimates)
    : data_dumper_(data_dumper),
      lag_updates_in_a_row_(num_lag_estimates, 0) {
  RTC_DCHECK(data_dumper);
  RTC_DCHECK_LT(0, num_lag_estimates);
}

MatchedFilterLagAggregator::~MatchedFilterLagAggregator() = default;
// Combines the per-filter lag estimates into a single lag.
//   lag_estimates: one estimate per filter; the count must match the number
//                  given at construction.
// Returns the current candidate lag once its repetition counter has reached
// 10, and an empty Optional otherwise.
rtc::Optional<size_t> MatchedFilterLagAggregator::Aggregate(
    rtc::ArrayView<const MatchedFilter::LagEstimate> lag_estimates) {
  RTC_DCHECK_EQ(lag_updates_in_a_row_.size(), lag_estimates.size());

  // Track for how many consecutive calls each filter has been updated so that
  // only stable lags are taken into account.
  for (size_t i = 0; i < lag_estimates.size(); ++i) {
    if (lag_estimates[i].updated) {
      ++lag_updates_in_a_row_[i];
    } else {
      lag_updates_in_a_row_[i] = 0;
    }
  }

  // Among the estimates that are both stable and reliable, pick the one with
  // the highest accuracy. On equal accuracy the earlier one is kept.
  int best_lag_estimate_index = -1;
  for (size_t i = 0; i < lag_estimates.size(); ++i) {
    if (lag_updates_in_a_row_[i] <= 10 || !lag_estimates[i].reliable) {
      continue;
    }
    if (best_lag_estimate_index == -1 ||
        lag_estimates[i].accuracy >
            lag_estimates[best_lag_estimate_index].accuracy) {
      best_lag_estimate_index = i;
    }
  }

  // TODO(peah): Remove this logging once all development is done.
  data_dumper_->DumpRaw("aec3_echo_path_delay_estimator_best_index",
                        best_lag_estimate_index);

  // Count how many times in a row the best lag has matched the candidate.
  // When no estimate qualified, the candidate and its counter are left as
  // they are.
  if (best_lag_estimate_index >= 0) {
    const size_t best_lag = lag_estimates[best_lag_estimate_index].lag;
    if (candidate_ == best_lag) {
      ++candidate_counter_;
    } else {
      candidate_counter_ = 0;
    }
    candidate_ = best_lag;
  }

  if (candidate_counter_ >= 10) {
    return rtc::Optional<size_t>(candidate_);
  }
  return rtc::Optional<size_t>();
}
} // namespace webrtc

View File

@ -0,0 +1,46 @@
/*
* Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_LAG_AGGREGATOR_H_
#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_LAG_AGGREGATOR_H_
#include <vector>
#include "webrtc/base/constructormagic.h"
#include "webrtc/base/optional.h"
#include "webrtc/modules/audio_processing/aec3/matched_filter.h"
namespace webrtc {
class ApmDataDumper;
// Aggregates lag estimates produced by the MatchedFilter class into a single
// reliable combined lag estimate.
class MatchedFilterLagAggregator {
public:
// |data_dumper| must be non-null and |num_lag_estimates| must be larger than
// zero. The data dumper pointer is stored but not deleted by this class.
MatchedFilterLagAggregator(ApmDataDumper* data_dumper,
size_t num_lag_estimates);
~MatchedFilterLagAggregator();
// Aggregates the provided lag estimates. The number of estimates must match
// the |num_lag_estimates| passed to the constructor. Returns an empty
// Optional until a lag has been selected enough times in a row.
rtc::Optional<size_t> Aggregate(
rtc::ArrayView<const MatchedFilter::LagEstimate> lag_estimates);
private:
// Used for dumping debug data.
ApmDataDumper* const data_dumper_;
// Per lag estimate, the number of consecutive Aggregate() calls in which the
// estimate was flagged as updated.
std::vector<size_t> lag_updates_in_a_row_;
// The most recently selected lag.
size_t candidate_ = 0;
// Counts how many times in a row |candidate_| has been re-selected.
size_t candidate_counter_ = 0;
RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(MatchedFilterLagAggregator);
};
} // namespace webrtc
#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_LAG_AGGREGATOR_H_

View File

@ -0,0 +1,192 @@
/*
* Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.h"
#include <sstream>
#include <string>
#include <vector>
#include "webrtc/base/array_view.h"
#include "webrtc/modules/audio_processing/aec3/aec3_constants.h"
#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h"
#include "webrtc/test/gtest.h"
namespace webrtc {
namespace {
// Feeds the same |lag_estimates| to |aggregator| |num_repetitions| times and
// expects each call to return an empty result.
void VerifyNoAggregateOutputForRepeatedLagAggregation(
    size_t num_repetitions,
    rtc::ArrayView<const MatchedFilter::LagEstimate> lag_estimates,
    MatchedFilterLagAggregator* aggregator) {
  for (size_t remaining = num_repetitions; remaining > 0; --remaining) {
    EXPECT_FALSE(aggregator->Aggregate(lag_estimates));
  }
}
// Number of Aggregate() calls with updated estimates that the tests below
// allow for before a lag estimate is expected to be taken into account.
constexpr size_t kThresholdForRequiredLagUpdatesInARow = 10;
// Number of further Aggregate() calls with an unchanged lag that the tests
// below allow for before an aggregated lag is expected to be reported.
constexpr size_t kThresholdForRequiredIdenticalLagAggregates = 10;
} // namespace
// Verifies that the most accurate lag estimate is chosen, both initially and
// after the accuracies of the two estimates have been swapped.
TEST(MatchedFilterLagAggregator, MostAccurateLagChosen) {
  constexpr size_t kArtificialLag1 = 5;
  constexpr size_t kArtificialLag2 = 10;
  ApmDataDumper data_dumper(0);
  std::vector<MatchedFilter::LagEstimate> lag_estimates(2);
  MatchedFilterLagAggregator aggregator(&data_dumper, lag_estimates.size());

  // The first estimate has the highest accuracy.
  lag_estimates[0] =
      MatchedFilter::LagEstimate(1.f, true, kArtificialLag1, true);
  lag_estimates[1] =
      MatchedFilter::LagEstimate(0.5f, true, kArtificialLag2, true);

  VerifyNoAggregateOutputForRepeatedLagAggregation(
      kThresholdForRequiredLagUpdatesInARow +
          kThresholdForRequiredIdenticalLagAggregates,
      lag_estimates, &aggregator);
  rtc::Optional<size_t> aggregated_lag = aggregator.Aggregate(lag_estimates);
  // Fatal assertion: the Optional is dereferenced on the next line.
  ASSERT_TRUE(aggregated_lag);
  EXPECT_EQ(kArtificialLag1, *aggregated_lag);

  // Swap the accuracies so that the second estimate becomes the best one.
  lag_estimates[0] =
      MatchedFilter::LagEstimate(0.5f, true, kArtificialLag1, true);
  lag_estimates[1] =
      MatchedFilter::LagEstimate(1.f, true, kArtificialLag2, true);

  VerifyNoAggregateOutputForRepeatedLagAggregation(
      kThresholdForRequiredIdenticalLagAggregates, lag_estimates, &aggregator);
  aggregated_lag = aggregator.Aggregate(lag_estimates);
  ASSERT_TRUE(aggregated_lag);
  EXPECT_EQ(kArtificialLag2, *aggregated_lag);
}
// Verifies that a change of the lag estimate causes the aggregated lag to be
// withheld until the new lag has been observed repeatedly in a row.
TEST(MatchedFilterLagAggregator,
     LagEstimateInvarianceRequiredForAggregatedLag) {
  constexpr size_t kArtificialLag1 = 5;
  constexpr size_t kArtificialLag2 = 10;
  ApmDataDumper data_dumper(0);
  std::vector<MatchedFilter::LagEstimate> lag_estimates(1);
  MatchedFilterLagAggregator aggregator(&data_dumper, lag_estimates.size());

  lag_estimates[0] =
      MatchedFilter::LagEstimate(1.f, true, kArtificialLag1, true);
  VerifyNoAggregateOutputForRepeatedLagAggregation(
      kThresholdForRequiredLagUpdatesInARow +
          kThresholdForRequiredIdenticalLagAggregates,
      lag_estimates, &aggregator);
  rtc::Optional<size_t> aggregated_lag = aggregator.Aggregate(lag_estimates);
  // Fatal assertion: the Optional is dereferenced on the next line.
  ASSERT_TRUE(aggregated_lag);
  EXPECT_EQ(kArtificialLag1, *aggregated_lag);

  // Switch to another lag; no output is expected until it has stabilized.
  lag_estimates[0] =
      MatchedFilter::LagEstimate(1.f, true, kArtificialLag2, true);
  VerifyNoAggregateOutputForRepeatedLagAggregation(
      kThresholdForRequiredIdenticalLagAggregates, lag_estimates, &aggregator);
  aggregated_lag = aggregator.Aggregate(lag_estimates);
  ASSERT_TRUE(aggregated_lag);
  EXPECT_EQ(kArtificialLag2, *aggregated_lag);
}
// Verifies that lag estimate updates are required to produce an updated lag
// aggregate: a new lag that is not flagged as updated must not replace the
// previously aggregated lag.
TEST(MatchedFilterLagAggregator, LagEstimateUpdatesRequiredForAggregatedLag) {
  constexpr size_t kArtificialLag1 = 5;
  constexpr size_t kArtificialLag2 = 10;
  ApmDataDumper data_dumper(0);
  std::vector<MatchedFilter::LagEstimate> lag_estimates(1);
  MatchedFilterLagAggregator aggregator(&data_dumper, lag_estimates.size());

  lag_estimates[0] =
      MatchedFilter::LagEstimate(1.f, true, kArtificialLag1, true);
  VerifyNoAggregateOutputForRepeatedLagAggregation(
      kThresholdForRequiredLagUpdatesInARow +
          kThresholdForRequiredIdenticalLagAggregates,
      lag_estimates, &aggregator);
  rtc::Optional<size_t> aggregated_lag = aggregator.Aggregate(lag_estimates);
  // Fatal assertions are used wherever the Optional is dereferenced next.
  ASSERT_TRUE(aggregated_lag);
  EXPECT_EQ(kArtificialLag1, *aggregated_lag);

  // A new lag that is never flagged as updated must leave the old aggregate
  // in place.
  lag_estimates[0] =
      MatchedFilter::LagEstimate(1.f, true, kArtificialLag2, false);
  for (size_t k = 0; k < kThresholdForRequiredLagUpdatesInARow +
                             kThresholdForRequiredIdenticalLagAggregates + 1;
       ++k) {
    aggregated_lag = aggregator.Aggregate(lag_estimates);
    ASSERT_TRUE(aggregated_lag);
    EXPECT_EQ(kArtificialLag1, *aggregated_lag);
  }

  // Once the new lag is flagged as updated, the old aggregate is still
  // reported while the new estimate accumulates updates in a row.
  lag_estimates[0] =
      MatchedFilter::LagEstimate(1.f, true, kArtificialLag2, true);
  for (size_t k = 0; k < kThresholdForRequiredLagUpdatesInARow; ++k) {
    aggregated_lag = aggregator.Aggregate(lag_estimates);
    ASSERT_TRUE(aggregated_lag);
    EXPECT_EQ(kArtificialLag1, *aggregated_lag);
  }

  // After that, no aggregate is reported until the new lag has stabilized.
  VerifyNoAggregateOutputForRepeatedLagAggregation(
      kThresholdForRequiredIdenticalLagAggregates, lag_estimates, &aggregator);

  aggregated_lag = aggregator.Aggregate(lag_estimates);
  ASSERT_TRUE(aggregated_lag);
  EXPECT_EQ(kArtificialLag2, *aggregated_lag);
}
// Verifies that an aggregated lag is persistent if the lag estimates do not
// change and that an aggregated lag is not produced without gaining lag
// estimate confidence.
TEST(MatchedFilterLagAggregator, PersistentAggregatedLag) {
  constexpr size_t kArtificialLag = 5;
  ApmDataDumper data_dumper(0);
  std::vector<MatchedFilter::LagEstimate> lag_estimates(1);
  MatchedFilterLagAggregator aggregator(&data_dumper, lag_estimates.size());

  lag_estimates[0] =
      MatchedFilter::LagEstimate(1.f, true, kArtificialLag, true);
  VerifyNoAggregateOutputForRepeatedLagAggregation(
      kThresholdForRequiredLagUpdatesInARow +
          kThresholdForRequiredIdenticalLagAggregates,
      lag_estimates, &aggregator);
  rtc::Optional<size_t> aggregated_lag = aggregator.Aggregate(lag_estimates);
  // Fatal assertion: the Optional is dereferenced on the next line.
  ASSERT_TRUE(aggregated_lag);
  EXPECT_EQ(kArtificialLag, *aggregated_lag);

  // A further call with unchanged estimates must report the same lag.
  aggregated_lag = aggregator.Aggregate(lag_estimates);
  ASSERT_TRUE(aggregated_lag);
  EXPECT_EQ(kArtificialLag, *aggregated_lag);
}
#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
// Verifies that passing a different number of lag estimates than the
// aggregator was constructed for triggers a crash.
TEST(MatchedFilterLagAggregator, IncorrectNumberOfLagEstimates) {
  ApmDataDumper dumper(0);
  MatchedFilterLagAggregator single_estimate_aggregator(&dumper, 1);
  std::vector<MatchedFilter::LagEstimate> two_estimates(2);
  EXPECT_DEATH(single_estimate_aggregator.Aggregate(two_estimates), "");
}
// Verifies that constructing an aggregator for zero lag estimates triggers a
// crash.
TEST(MatchedFilterLagAggregator, NonZeroLagEstimates) {
  ApmDataDumper dumper(0);
  EXPECT_DEATH(MatchedFilterLagAggregator(&dumper, 0), "");
}
// Verifies that constructing an aggregator with a null data dumper triggers a
// crash.
TEST(MatchedFilterLagAggregator, NullDataDumper) {
  ApmDataDumper* const null_data_dumper = nullptr;
  EXPECT_DEATH(MatchedFilterLagAggregator(null_data_dumper, 1), "");
}
#endif
} // namespace webrtc

View File

@ -0,0 +1,190 @@
/*
* Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "webrtc/modules/audio_processing/aec3/matched_filter.h"
#include <algorithm>
#include <sstream>
#include <string>
#include "webrtc/modules/audio_processing/aec3/aec3_constants.h"
#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h"
#include "webrtc/modules/audio_processing/test/echo_canceller_test_tools.h"
#include "webrtc/test/gtest.h"
namespace webrtc {
namespace {
// Builds the trace label "Delay: <delay>" used to identify the delay under
// test in failure output.
std::string ProduceDebugText(size_t delay) {
  return "Delay: " + std::to_string(delay);
}
// Window size argument passed to the MatchedFilter constructor in the tests
// below, expressed in sub-blocks.
constexpr size_t kWindowSizeSubBlocks = 32;
// Alignment shift argument passed to the MatchedFilter constructor; three
// quarters of the window size.
constexpr size_t kAlignmentShiftSubBlocks = kWindowSizeSubBlocks * 3 / 4;
// Number of matched filters requested from the MatchedFilter in the tests
// below.
constexpr size_t kNumMatchedFilters = 4;
} // namespace
// Verifies that the matched filter produces proper lag estimates for
// artificially delayed signals.
TEST(MatchedFilter, LagEstimation) {
  Random random_generator(42U);
  std::array<float, kSubBlockSize> render;
  std::array<float, kSubBlockSize> capture;
  render.fill(0.f);
  capture.fill(0.f);
  ApmDataDumper data_dumper(0);
  for (size_t delay_samples : {0, 64, 150, 200, 800, 1000}) {
    SCOPED_TRACE(ProduceDebugText(delay_samples));
    DelayBuffer<float> signal_delay_buffer(delay_samples);
    MatchedFilter filter(&data_dumper, kWindowSizeSubBlocks, kNumMatchedFilters,
                         kAlignmentShiftSubBlocks);

    // Analyze the correlation between render and capture, where capture is
    // the render signal passed through the delay buffer.
    for (size_t k = 0; k < (100 + delay_samples / kSubBlockSize); ++k) {
      RandomizeSampleVector(&random_generator, render);
      signal_delay_buffer.Delay(render, capture);
      filter.Update(render, capture);
    }

    // Obtain the lag estimates. The size is asserted since the estimates are
    // indexed by filter number below.
    auto lag_estimates = filter.GetLagEstimates();
    ASSERT_EQ(kNumMatchedFilters, lag_estimates.size());

    // Find which lag estimate should be the most accurate: the first filter
    // for which the accumulated alignment shift plus half the window size,
    // converted to samples, exceeds the delay.
    rtc::Optional<size_t> expected_most_accurate_lag_estimate;
    size_t alignment_shift_sub_blocks = 0;
    for (size_t k = 0; k < kNumMatchedFilters; ++k) {
      if ((alignment_shift_sub_blocks + kWindowSizeSubBlocks / 2) *
              kSubBlockSize >
          delay_samples) {
        expected_most_accurate_lag_estimate = rtc::Optional<size_t>(k);
        break;
      }
      alignment_shift_sub_blocks += kAlignmentShiftSubBlocks;
    }
    ASSERT_TRUE(expected_most_accurate_lag_estimate);

    // Verify that the expected most accurate lag estimate is the most accurate
    // estimate.
    for (size_t k = 0; k < kNumMatchedFilters; ++k) {
      if (k != *expected_most_accurate_lag_estimate) {
        EXPECT_GT(lag_estimates[*expected_most_accurate_lag_estimate].accuracy,
                  lag_estimates[k].accuracy);
      }
    }

    // Verify that all lag estimates are updated as expected for signals
    // containing strong noise.
    for (const auto& le : lag_estimates) {
      EXPECT_TRUE(le.updated);
    }

    // Verify that the expected most accurate lag estimate is reliable.
    EXPECT_TRUE(lag_estimates[*expected_most_accurate_lag_estimate].reliable);

    // Verify that the expected most accurate lag estimate is correct.
    EXPECT_EQ(delay_samples,
              lag_estimates[*expected_most_accurate_lag_estimate].lag);
  }
}
// Verifies that no lag estimate is flagged as reliable when the render and
// capture signals are independently randomized.
TEST(MatchedFilter, LagNotReliableForUncorrelatedRenderAndCapture) {
  Random random_generator(42U);
  std::array<float, kSubBlockSize> render;
  std::array<float, kSubBlockSize> capture;
  render.fill(0.f);
  capture.fill(0.f);
  ApmDataDumper data_dumper(0);
  MatchedFilter filter(&data_dumper, kWindowSizeSubBlocks, kNumMatchedFilters,
                       kAlignmentShiftSubBlocks);

  // Feed the filter with 100 sub-blocks of independently randomized render
  // and capture data.
  for (size_t remaining = 100; remaining > 0; --remaining) {
    RandomizeSampleVector(&random_generator, render);
    RandomizeSampleVector(&random_generator, capture);
    filter.Update(render, capture);
  }

  // One lag estimate per matched filter is expected, none of them reliable.
  auto lag_estimates = filter.GetLagEstimates();
  EXPECT_EQ(kNumMatchedFilters, lag_estimates.size());
  for (auto& estimate : lag_estimates) {
    EXPECT_FALSE(estimate.reliable);
  }
}
// Verifies that lag estimates are neither updated nor deemed reliable when
// the render signal is scaled down to a low level, even though the capture
// signal is an exact copy of it.
TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) {
  Random random_generator(42U);
  std::array<float, kSubBlockSize> render;
  std::array<float, kSubBlockSize> capture;
  render.fill(0.f);
  capture.fill(0.f);
  ApmDataDumper data_dumper(0);
  MatchedFilter filter(&data_dumper, kWindowSizeSubBlocks, kNumMatchedFilters,
                       kAlignmentShiftSubBlocks);

  // Feed the filter with 100 sub-blocks of attenuated random render data and
  // an identical capture signal.
  for (size_t remaining = 100; remaining > 0; --remaining) {
    RandomizeSampleVector(&random_generator, render);
    for (float& sample : render) {
      sample *= 149.f / 32767.f;
    }
    capture = render;
    filter.Update(render, capture);
  }

  // One lag estimate per matched filter is expected, none of them updated or
  // reliable.
  auto lag_estimates = filter.GetLagEstimates();
  EXPECT_EQ(kNumMatchedFilters, lag_estimates.size());
  for (auto& estimate : lag_estimates) {
    EXPECT_FALSE(estimate.updated);
    EXPECT_FALSE(estimate.reliable);
  }
}
// Verifies that the number of lag estimates produced equals the number of
// matched filters requested, for 0 up to 9 filters.
TEST(MatchedFilter, NumberOfLagEstimates) {
  ApmDataDumper data_dumper(0);
  size_t num_matched_filters = 0;
  while (num_matched_filters < 10) {
    MatchedFilter filter(&data_dumper, 32, num_matched_filters, 1);
    EXPECT_EQ(num_matched_filters, filter.GetLagEstimates().size());
    ++num_matched_filters;
  }
}
#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
// Verifies that constructing a matched filter with a zero window size
// triggers a crash.
TEST(MatchedFilter, ZeroWindowSize) {
  ApmDataDumper dumper(0);
  EXPECT_DEATH(MatchedFilter(&dumper, 0, 1, 1), "");
}
// Verifies that constructing a matched filter with a null data dumper
// triggers a crash.
TEST(MatchedFilter, NullDataDumper) {
  ApmDataDumper* const null_data_dumper = nullptr;
  EXPECT_DEATH(MatchedFilter(null_data_dumper, 1, 1, 1), "");
}
#endif
} // namespace webrtc

View File

@ -115,9 +115,8 @@ RenderDelayControllerImpl::RenderDelayControllerImpl(
max_delay_(render_delay_buffer.MaxDelay()),
delay_(render_delay_buffer.Delay()),
render_buffer_(render_delay_buffer.MaxApiJitter() + 1),
delay_estimator_(data_dumper_.get(), sample_rate_hz) {
RTC_DCHECK(sample_rate_hz == 8000 || sample_rate_hz == 16000 ||
sample_rate_hz == 32000 || sample_rate_hz == 48000);
delay_estimator_(data_dumper_.get()) {
RTC_DCHECK(ValidFullBandRate(sample_rate_hz));
}
RenderDelayControllerImpl::~RenderDelayControllerImpl() = default;

View File

@ -105,8 +105,7 @@ TEST(RenderDelayController, BasicApiCalls) {
// Verifies that the RenderDelayController is able to align the signals for
// simple timeshifts between the signals.
// TODO(peah): Activate the unittest once the required code has been landed.
TEST(RenderDelayController, DISABLED_Alignment) {
TEST(RenderDelayController, Alignment) {
Random random_generator(42U);
std::vector<float> render_block(kBlockSize, 0.f);
std::vector<float> capture_block(kBlockSize, 0.f);
@ -148,8 +147,7 @@ TEST(RenderDelayController, DISABLED_Alignment) {
// Verifies that the RenderDelayController is able to align the signals for
// simple timeshifts between the signals when there is jitter in the API calls.
// TODO(peah): Activate the unittest once the required code has been landed.
TEST(RenderDelayController, DISABLED_AlignmentWithJitter) {
TEST(RenderDelayController, AlignmentWithJitter) {
Random random_generator(42U);
std::vector<float> render_block(kBlockSize, 0.f);
std::vector<float> capture_block(kBlockSize, 0.f);