Add parameterization for three multi channel AEC3 unit tests

Bug: webrtc:11295
Change-Id: I478aa02908c494cf9609db00021438a59a132b66
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/167202
Commit-Queue: Sam Zackrisson <saza@webrtc.org>
Reviewed-by: Per Åhgren <peah@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#30370}
This commit is contained in:
Sam Zackrisson 2020-01-24 12:55:17 +01:00 committed by Commit Bot
parent 159c414ff8
commit b18c4eb0a9
10 changed files with 849 additions and 795 deletions

View File

@ -51,12 +51,21 @@ std::string ProduceDebugText(size_t num_render_channels, size_t delay) {
} // namespace
class AdaptiveFirFilterOneTwoFourEightRenderChannels
: public ::testing::Test,
public ::testing::WithParamInterface<size_t> {};
INSTANTIATE_TEST_SUITE_P(MultiChannel,
AdaptiveFirFilterOneTwoFourEightRenderChannels,
::testing::Values(1, 2, 4, 8));
#if defined(WEBRTC_HAS_NEON)
// Verifies that the optimized methods for filter adaptation are similar to
// their reference counterparts.
TEST(AdaptiveFirFilter, FilterAdaptationNeonOptimizations) {
TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels,
FilterAdaptationNeonOptimizations) {
const size_t num_render_channels = GetParam();
for (size_t num_partitions : {2, 5, 12, 30, 50}) {
for (size_t num_render_channels : {1, 2, 4, 8}) {
constexpr int kSampleRateHz = 48000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
@ -128,18 +137,17 @@ TEST(AdaptiveFirFilter, FilterAdaptationNeonOptimizations) {
}
}
}
}
// Verifies that the optimized method for frequency response computation is
// bitexact to the reference counterpart.
TEST(AdaptiveFirFilter, ComputeFrequencyResponseNeonOptimization) {
TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels,
ComputeFrequencyResponseNeonOptimization) {
const size_t num_render_channels = GetParam();
for (size_t num_partitions : {2, 5, 12, 30, 50}) {
for (size_t num_render_channels : {1, 2, 4, 8}) {
std::vector<std::vector<FftData>> H(
num_partitions, std::vector<FftData>(num_render_channels));
std::vector<std::array<float, kFftLengthBy2Plus1>> H2(num_partitions);
std::vector<std::array<float, kFftLengthBy2Plus1>> H2_Neon(
num_partitions);
std::vector<std::array<float, kFftLengthBy2Plus1>> H2_Neon(num_partitions);
for (size_t p = 0; p < num_partitions; ++p) {
for (size_t ch = 0; ch < num_render_channels; ++ch) {
@ -160,28 +168,28 @@ TEST(AdaptiveFirFilter, ComputeFrequencyResponseNeonOptimization) {
}
}
}
}
#endif
#if defined(WEBRTC_ARCH_X86_FAMILY)
// Verifies that the optimized methods for filter adaptation are bitexact to
// their reference counterparts.
TEST(AdaptiveFirFilter, FilterAdaptationSse2Optimizations) {
TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels,
FilterAdaptationSse2Optimizations) {
const size_t num_render_channels = GetParam();
constexpr int kSampleRateHz = 48000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0);
if (use_sse2) {
for (size_t num_partitions : {2, 5, 12, 30, 50}) {
for (size_t num_render_channels : {1, 2, 4, 8}) {
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz,
num_render_channels));
Random random_generator(42U);
std::vector<std::vector<std::vector<float>>> x(
kNumBands,
std::vector<std::vector<float>>(
num_render_channels, std::vector<float>(kBlockSize, 0.f)));
std::vector<std::vector<float>>(num_render_channels,
std::vector<float>(kBlockSize, 0.f)));
FftData S_C;
FftData S_Sse2;
FftData G;
@ -237,15 +245,15 @@ TEST(AdaptiveFirFilter, FilterAdaptationSse2Optimizations) {
}
}
}
}
// Verifies that the optimized method for frequency response computation is
// bitexact to the reference counterpart.
TEST(AdaptiveFirFilter, ComputeFrequencyResponseSse2Optimization) {
TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels,
ComputeFrequencyResponseSse2Optimization) {
const size_t num_render_channels = GetParam();
bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0);
if (use_sse2) {
for (size_t num_partitions : {2, 5, 12, 30, 50}) {
for (size_t num_render_channels : {1, 2, 4, 8}) {
std::vector<std::vector<FftData>> H(
num_partitions, std::vector<FftData>(num_render_channels));
std::vector<std::array<float, kFftLengthBy2Plus1>> H2(num_partitions);
@ -272,19 +280,18 @@ TEST(AdaptiveFirFilter, ComputeFrequencyResponseSse2Optimization) {
}
}
}
}
#endif
#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
// Verifies that the check for non-null data dumper works.
TEST(AdaptiveFirFilter, NullDataDumper) {
TEST(AdaptiveFirFilterTest, NullDataDumper) {
EXPECT_DEATH(AdaptiveFirFilter(9, 9, 250, 1, DetectOptimization(), nullptr),
"");
}
// Verifies that the check for non-null filter output works.
TEST(AdaptiveFirFilter, NullFilterOutput) {
TEST(AdaptiveFirFilterTest, NullFilterOutput) {
ApmDataDumper data_dumper(42);
AdaptiveFirFilter filter(9, 9, 250, 1, DetectOptimization(), &data_dumper);
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
@ -297,7 +304,7 @@ TEST(AdaptiveFirFilter, NullFilterOutput) {
// Verifies that the filter statistics can be accessed when filter statistics
// are turned on.
TEST(AdaptiveFirFilter, FilterStatisticsAccess) {
TEST(AdaptiveFirFilterTest, FilterStatisticsAccess) {
ApmDataDumper data_dumper(42);
Aec3Optimization optimization = DetectOptimization();
AdaptiveFirFilter filter(9, 9, 250, 1, optimization, &data_dumper);
@ -314,7 +321,7 @@ TEST(AdaptiveFirFilter, FilterStatisticsAccess) {
}
// Verifies that the filter size if correctly repported.
TEST(AdaptiveFirFilter, FilterSize) {
TEST(AdaptiveFirFilterTest, FilterSize) {
ApmDataDumper data_dumper(42);
for (size_t filter_size = 1; filter_size < 5; ++filter_size) {
AdaptiveFirFilter filter(filter_size, filter_size, 250, 1,
@ -323,23 +330,32 @@ TEST(AdaptiveFirFilter, FilterSize) {
}
}
class AdaptiveFirFilterMultiChannel
: public ::testing::Test,
public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
INSTANTIATE_TEST_SUITE_P(MultiChannel,
AdaptiveFirFilterMultiChannel,
::testing::Combine(::testing::Values(1, 4),
::testing::Values(1, 8)));
// Verifies that the filter is being able to properly filter a signal and to
// adapt its coefficients.
TEST(AdaptiveFirFilter, FilterAndAdapt) {
TEST_P(AdaptiveFirFilterMultiChannel, FilterAndAdapt) {
const size_t num_render_channels = std::get<0>(GetParam());
const size_t num_capture_channels = std::get<1>(GetParam());
constexpr int kSampleRateHz = 48000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
constexpr size_t kNumBlocksToProcessPerRenderChannel = 1000;
for (size_t num_capture_channels : {1, 4}) {
for (size_t num_render_channels : {1, 8}) {
ApmDataDumper data_dumper(42);
EchoCanceller3Config config;
if (num_render_channels == 33) {
config.filter.main = {13, 0.00005f, 0.0005f, 0.0001f, 2.f, 20075344.f};
config.filter.shadow = {13, 0.1f, 20075344.f};
config.filter.main_initial = {12, 0.005f, 0.5f,
0.001f, 2.f, 20075344.f};
config.filter.main_initial = {12, 0.005f, 0.5f, 0.001f, 2.f, 20075344.f};
config.filter.shadow_initial = {12, 0.7f, 20075344.f};
}
@ -348,8 +364,7 @@ TEST(AdaptiveFirFilter, FilterAndAdapt) {
config.filter.config_change_duration_blocks, num_render_channels,
DetectOptimization(), &data_dumper);
std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> H2(
num_capture_channels,
std::vector<std::array<float, kFftLengthBy2Plus1>>(
num_capture_channels, std::vector<std::array<float, kFftLengthBy2Plus1>>(
filter.max_filter_size_partitions(),
std::array<float, kFftLengthBy2Plus1>()));
std::vector<std::vector<float>> h(
@ -359,15 +374,13 @@ TEST(AdaptiveFirFilter, FilterAndAdapt) {
Aec3Fft fft;
config.delay.default_delay = 1;
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(config, kSampleRateHz,
num_render_channels));
RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels));
ShadowFilterUpdateGain gain(config.filter.shadow,
config.filter.config_change_duration_blocks);
Random random_generator(42U);
std::vector<std::vector<std::vector<float>>> x(
kNumBands,
std::vector<std::vector<float>>(num_render_channels,
std::vector<float>(kBlockSize, 0.f)));
kNumBands, std::vector<std::vector<float>>(
num_render_channels, std::vector<float>(kBlockSize, 0.f)));
std::vector<float> n(kBlockSize, 0.f);
std::vector<float> y(kBlockSize, 0.f);
AecState aec_state(EchoCanceller3Config{}, num_capture_channels);
@ -379,8 +392,7 @@ TEST(AdaptiveFirFilter, FilterAndAdapt) {
FftData S;
FftData G;
FftData E;
std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(
num_capture_channels);
std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels);
std::vector<std::array<float, kFftLengthBy2Plus1>> E2_main(
num_capture_channels);
std::array<float, kFftLengthBy2Plus1> E2_shadow;
@ -452,9 +464,8 @@ TEST(AdaptiveFirFilter, FilterAndAdapt) {
std::transform(y.begin(), y.end(), s_scratch.begin() + kFftLengthBy2,
e.begin(),
[&](float a, float b) { return a - b * kScale; });
std::for_each(e.begin(), e.end(), [](float& a) {
a = rtc::SafeClamp(a, -32768.f, 32767.f);
});
std::for_each(e.begin(), e.end(),
[](float& a) { a = rtc::SafeClamp(a, -32768.f, 32767.f); });
fft.ZeroPaddedFft(e, Aec3Fft::Window::kRectangular, &E);
for (auto& o : output) {
for (size_t k = 0; k < kBlockSize; ++k) {
@ -479,7 +490,6 @@ TEST(AdaptiveFirFilter, FilterAndAdapt) {
std::inner_product(y.begin(), y.end(), y.begin(), 0.f));
}
}
}
}
} // namespace aec3
} // namespace webrtc

View File

@ -18,13 +18,6 @@
namespace webrtc {
namespace {
std::string ProduceDebugText(size_t num_render_channels,
size_t num_capture_channels) {
rtc::StringBuilder ss;
ss << "Render channels: " << num_render_channels;
ss << ", Capture channels: " << num_capture_channels;
return ss.Release();
}
void RunNormalUsageTest(size_t num_render_channels,
size_t num_capture_channels) {
@ -232,15 +225,21 @@ void RunNormalUsageTest(size_t num_render_channels,
} // namespace
class AecStateMultiChannel
: public ::testing::Test,
public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
INSTANTIATE_TEST_SUITE_P(MultiChannel,
AecStateMultiChannel,
::testing::Combine(::testing::Values(1, 2, 8),
::testing::Values(1, 2, 8)));
// Verify the general functionality of AecState
TEST(AecState, NormalUsage) {
for (size_t num_render_channels : {1, 2, 8}) {
for (size_t num_capture_channels : {1, 2, 8}) {
SCOPED_TRACE(ProduceDebugText(num_render_channels, num_capture_channels));
TEST_P(AecStateMultiChannel, NormalUsage) {
const size_t num_render_channels = std::get<0>(GetParam());
const size_t num_capture_channels = std::get<1>(GetParam());
RunNormalUsageTest(num_render_channels, num_capture_channels);
}
}
}
// Verifies the delay for a converged filter is correctly identified.
TEST(AecState, ConvergedFilterDelay) {

View File

@ -34,19 +34,26 @@ std::string ProduceDebugText(size_t delay, size_t down_sampling_factor) {
} // namespace
class EchoPathDelayEstimatorMultiChannel
: public ::testing::Test,
public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
INSTANTIATE_TEST_SUITE_P(MultiChannel,
EchoPathDelayEstimatorMultiChannel,
::testing::Combine(::testing::Values(1, 2, 3, 6, 8),
::testing::Values(1, 2, 4)));
// Verifies that the basic API calls work.
TEST(EchoPathDelayEstimator, BasicApiCalls) {
TEST_P(EchoPathDelayEstimatorMultiChannel, BasicApiCalls) {
const size_t num_render_channels = std::get<0>(GetParam());
const size_t num_capture_channels = std::get<1>(GetParam());
constexpr int kSampleRateHz = 48000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
for (size_t num_capture_channels : {1, 2, 4}) {
for (size_t num_render_channels : {1, 2, 3, 6, 8}) {
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(config, kSampleRateHz,
num_render_channels));
EchoPathDelayEstimator estimator(&data_dumper, config,
num_capture_channels);
RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels));
EchoPathDelayEstimator estimator(&data_dumper, config, num_capture_channels);
std::vector<std::vector<std::vector<float>>> render(
kNumBands, std::vector<std::vector<float>>(
num_render_channels, std::vector<float>(kBlockSize)));
@ -54,10 +61,8 @@ TEST(EchoPathDelayEstimator, BasicApiCalls) {
std::vector<float>(kBlockSize));
for (size_t k = 0; k < 100; ++k) {
render_delay_buffer->Insert(render);
estimator.EstimateDelay(
render_delay_buffer->GetDownsampledRenderBuffer(), capture);
}
}
estimator.EstimateDelay(render_delay_buffer->GetDownsampledRenderBuffer(),
capture);
}
}

View File

@ -26,7 +26,6 @@
namespace webrtc {
namespace {
std::string ProduceDebugText(int sample_rate_hz) {
rtc::StringBuilder ss;
ss << "Sample rate: " << sample_rate_hz;
@ -41,43 +40,48 @@ std::string ProduceDebugText(int sample_rate_hz, int delay) {
} // namespace
class EchoRemoverMultiChannel
: public ::testing::Test,
public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
INSTANTIATE_TEST_SUITE_P(MultiChannel,
EchoRemoverMultiChannel,
::testing::Combine(::testing::Values(1, 2, 8),
::testing::Values(1, 2, 8)));
// Verifies the basic API call sequence
TEST(EchoRemover, BasicApiCalls) {
TEST_P(EchoRemoverMultiChannel, BasicApiCalls) {
const size_t num_render_channels = std::get<0>(GetParam());
const size_t num_capture_channels = std::get<1>(GetParam());
absl::optional<DelayEstimate> delay_estimate;
for (auto rate : {16000, 32000, 48000}) {
for (size_t num_render_channels : {1, 2, 8}) {
for (size_t num_capture_channels : {1, 2, 8}) {
SCOPED_TRACE(ProduceDebugText(rate));
std::unique_ptr<EchoRemover> remover(
EchoRemover::Create(EchoCanceller3Config(), rate,
num_render_channels, num_capture_channels));
std::unique_ptr<RenderDelayBuffer> render_buffer(
RenderDelayBuffer::Create(EchoCanceller3Config(), rate,
num_render_channels));
EchoRemover::Create(EchoCanceller3Config(), rate, num_render_channels,
num_capture_channels));
std::unique_ptr<RenderDelayBuffer> render_buffer(RenderDelayBuffer::Create(
EchoCanceller3Config(), rate, num_render_channels));
std::vector<std::vector<std::vector<float>>> render(
NumBandsForRate(rate),
std::vector<std::vector<float>>(
num_render_channels, std::vector<float>(kBlockSize, 0.f)));
std::vector<std::vector<float>>(num_render_channels,
std::vector<float>(kBlockSize, 0.f)));
std::vector<std::vector<std::vector<float>>> capture(
NumBandsForRate(rate),
std::vector<std::vector<float>>(
num_capture_channels, std::vector<float>(kBlockSize, 0.f)));
std::vector<std::vector<float>>(num_capture_channels,
std::vector<float>(kBlockSize, 0.f)));
for (size_t k = 0; k < 100; ++k) {
EchoPathVariability echo_path_variability(
k % 3 == 0 ? true : false,
k % 5 == 0
? EchoPathVariability::DelayAdjustment::kNewDetectedDelay
k % 5 == 0 ? EchoPathVariability::DelayAdjustment::kNewDetectedDelay
: EchoPathVariability::DelayAdjustment::kNone,
false);
render_buffer->Insert(render);
render_buffer->PrepareCaptureProcessing();
remover->ProcessCapture(
echo_path_variability, k % 2 == 0 ? true : false, delay_estimate,
render_buffer->GetRenderBuffer(), nullptr, &capture);
}
}
remover->ProcessCapture(echo_path_variability, k % 2 == 0 ? true : false,
delay_estimate, render_buffer->GetRenderBuffer(),
nullptr, &capture);
}
}
}

View File

@ -34,18 +34,25 @@ void VerifyErl(const std::array<float, kFftLengthBy2Plus1>& erl,
} // namespace
class ErlEstimatorMultiChannel
: public ::testing::Test,
public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
INSTANTIATE_TEST_SUITE_P(MultiChannel,
ErlEstimatorMultiChannel,
::testing::Combine(::testing::Values(1, 2, 8),
::testing::Values(1, 2, 8)));
// Verifies that the correct ERL estimates are achieved.
TEST(ErlEstimator, Estimates) {
for (size_t num_render_channels : {1, 2, 8}) {
for (size_t num_capture_channels : {1, 2, 8}) {
TEST_P(ErlEstimatorMultiChannel, Estimates) {
const size_t num_render_channels = std::get<0>(GetParam());
const size_t num_capture_channels = std::get<1>(GetParam());
SCOPED_TRACE(ProduceDebugText(num_render_channels, num_capture_channels));
std::vector<std::array<float, kFftLengthBy2Plus1>> X2(
num_render_channels);
std::vector<std::array<float, kFftLengthBy2Plus1>> X2(num_render_channels);
for (auto& X2_ch : X2) {
X2_ch.fill(0.f);
}
std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(
num_capture_channels);
std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels);
for (auto& Y2_ch : Y2) {
Y2_ch.fill(0.f);
}
@ -94,7 +101,4 @@ TEST(ErlEstimator, Estimates) {
}
VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 1000.f);
}
}
}
} // namespace webrtc

View File

@ -16,12 +16,12 @@
#include "modules/audio_processing/aec3/render_delay_buffer.h"
#include "modules/audio_processing/aec3/spectrum_buffer.h"
#include "rtc_base/random.h"
#include "rtc_base/strings/string_builder.h"
#include "test/gtest.h"
namespace webrtc {
namespace {
constexpr int kLowFrequencyLimit = kFftLengthBy2 / 2;
constexpr float kTrueErle = 10.f;
constexpr float kTrueErleOnsets = 1.0f;
@ -129,37 +129,40 @@ void GetFilterFreq(
} // namespace
TEST(ErleEstimator, VerifyErleIncreaseAndHold) {
class ErleEstimatorMultiChannel
: public ::testing::Test,
public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
INSTANTIATE_TEST_SUITE_P(MultiChannel,
ErleEstimatorMultiChannel,
::testing::Combine(::testing::Values(1, 2, 4, 8),
::testing::Values(1, 2, 8)));
TEST_P(ErleEstimatorMultiChannel, VerifyErleIncreaseAndHold) {
const size_t num_render_channels = std::get<0>(GetParam());
const size_t num_capture_channels = std::get<1>(GetParam());
constexpr int kSampleRateHz = 48000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
for (size_t num_render_channels : {1, 2, 4, 8}) {
for (size_t num_capture_channels : {1, 2, 4}) {
std::array<float, kFftLengthBy2Plus1> X2;
std::vector<std::array<float, kFftLengthBy2Plus1>> E2(
num_capture_channels);
std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(
num_capture_channels);
std::vector<std::array<float, kFftLengthBy2Plus1>> E2(num_capture_channels);
std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels);
std::vector<bool> converged_filters(num_capture_channels, true);
EchoCanceller3Config config;
config.erle.onset_detection = true;
std::vector<std::vector<std::vector<float>>> x(
kNumBands,
std::vector<std::vector<float>>(num_render_channels,
std::vector<float>(kBlockSize, 0.f)));
kNumBands, std::vector<std::vector<float>>(
num_render_channels, std::vector<float>(kBlockSize, 0.f)));
std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>>
filter_frequency_response(
config.filter.main.length_blocks,
std::vector<std::array<float, kFftLengthBy2Plus1>>(
num_capture_channels));
std::vector<std::array<float, kFftLengthBy2Plus1>>(num_capture_channels));
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(config, kSampleRateHz,
num_render_channels));
RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels));
GetFilterFreq(config.delay.delay_headroom_samples,
filter_frequency_response);
GetFilterFreq(config.delay.delay_headroom_samples, filter_frequency_response);
ErleEstimator estimator(0, config, num_capture_channels);
@ -167,14 +170,13 @@ TEST(ErleEstimator, VerifyErleIncreaseAndHold) {
render_delay_buffer->Insert(x);
render_delay_buffer->PrepareCaptureProcessing();
// Verifies that the ERLE estimate is properly increased to higher values.
FormFarendFrame(*render_delay_buffer->GetRenderBuffer(), kTrueErle, &X2,
E2, Y2);
FormFarendFrame(*render_delay_buffer->GetRenderBuffer(), kTrueErle, &X2, E2,
Y2);
for (size_t k = 0; k < 200; ++k) {
render_delay_buffer->Insert(x);
render_delay_buffer->PrepareCaptureProcessing();
estimator.Update(*render_delay_buffer->GetRenderBuffer(),
filter_frequency_response, X2, Y2, E2,
converged_filters);
filter_frequency_response, X2, Y2, E2, converged_filters);
}
VerifyErle(estimator.Erle(), std::pow(2.f, estimator.FullbandErleLog2()),
config.erle.max_l, config.erle.max_h);
@ -186,44 +188,35 @@ TEST(ErleEstimator, VerifyErleIncreaseAndHold) {
render_delay_buffer->Insert(x);
render_delay_buffer->PrepareCaptureProcessing();
estimator.Update(*render_delay_buffer->GetRenderBuffer(),
filter_frequency_response, X2, Y2, E2,
converged_filters);
filter_frequency_response, X2, Y2, E2, converged_filters);
}
VerifyErle(estimator.Erle(), std::pow(2.f, estimator.FullbandErleLog2()),
config.erle.max_l, config.erle.max_h);
}
}
}
TEST(ErleEstimator, VerifyErleTrackingOnOnsets) {
TEST_P(ErleEstimatorMultiChannel, VerifyErleTrackingOnOnsets) {
const size_t num_render_channels = std::get<0>(GetParam());
const size_t num_capture_channels = std::get<1>(GetParam());
constexpr int kSampleRateHz = 48000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
for (size_t num_render_channels : {1, 2, 4, 8}) {
for (size_t num_capture_channels : {1, 2, 4}) {
std::array<float, kFftLengthBy2Plus1> X2;
std::vector<std::array<float, kFftLengthBy2Plus1>> E2(
num_capture_channels);
std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(
num_capture_channels);
std::vector<std::array<float, kFftLengthBy2Plus1>> E2(num_capture_channels);
std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels);
std::vector<bool> converged_filters(num_capture_channels, true);
EchoCanceller3Config config;
config.erle.onset_detection = true;
std::vector<std::vector<std::vector<float>>> x(
kNumBands,
std::vector<std::vector<float>>(num_render_channels,
std::vector<float>(kBlockSize, 0.f)));
kNumBands, std::vector<std::vector<float>>(
num_render_channels, std::vector<float>(kBlockSize, 0.f)));
std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>>
filter_frequency_response(
config.filter.main.length_blocks,
std::vector<std::array<float, kFftLengthBy2Plus1>>(
num_capture_channels));
std::vector<std::array<float, kFftLengthBy2Plus1>>(num_capture_channels));
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(config, kSampleRateHz,
num_render_channels));
RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels));
GetFilterFreq(config.delay.delay_headroom_samples,
filter_frequency_response);
GetFilterFreq(config.delay.delay_headroom_samples, filter_frequency_response);
ErleEstimator estimator(/*startup_phase_length_blocks=*/0, config,
num_capture_channels);
@ -233,8 +226,8 @@ TEST(ErleEstimator, VerifyErleTrackingOnOnsets) {
render_delay_buffer->PrepareCaptureProcessing();
for (size_t burst = 0; burst < 20; ++burst) {
FormFarendFrame(*render_delay_buffer->GetRenderBuffer(),
kTrueErleOnsets, &X2, E2, Y2);
FormFarendFrame(*render_delay_buffer->GetRenderBuffer(), kTrueErleOnsets,
&X2, E2, Y2);
for (size_t k = 0; k < 10; ++k) {
render_delay_buffer->Insert(x);
render_delay_buffer->PrepareCaptureProcessing();
@ -242,8 +235,8 @@ TEST(ErleEstimator, VerifyErleTrackingOnOnsets) {
filter_frequency_response, X2, Y2, E2,
converged_filters);
}
FormFarendFrame(*render_delay_buffer->GetRenderBuffer(), kTrueErle, &X2,
E2, Y2);
FormFarendFrame(*render_delay_buffer->GetRenderBuffer(), kTrueErle, &X2, E2,
Y2);
for (size_t k = 0; k < 200; ++k) {
render_delay_buffer->Insert(x);
render_delay_buffer->PrepareCaptureProcessing();
@ -264,15 +257,12 @@ TEST(ErleEstimator, VerifyErleTrackingOnOnsets) {
FormNearendFrame(&x, &X2, E2, Y2);
for (size_t k = 0; k < 1000; k++) {
estimator.Update(*render_delay_buffer->GetRenderBuffer(),
filter_frequency_response, X2, Y2, E2,
converged_filters);
filter_frequency_response, X2, Y2, E2, converged_filters);
}
// Verifies that during ne activity, Erle converges to the Erle for
// onsets.
VerifyErle(estimator.Erle(), std::pow(2.f, estimator.FullbandErleLog2()),
config.erle.min, config.erle.min);
}
}
}
} // namespace webrtc

View File

@ -16,13 +16,23 @@
#include "modules/audio_processing/aec3/render_delay_buffer.h"
#include "modules/audio_processing/test/echo_canceller_test_tools.h"
#include "rtc_base/random.h"
#include "rtc_base/strings/string_builder.h"
#include "test/gtest.h"
namespace webrtc {
TEST(ResidualEchoEstimator, BasicTest) {
for (size_t num_render_channels : {1, 2, 4}) {
for (size_t num_capture_channels : {1, 2, 4}) {
class ResidualEchoEstimatorMultiChannel
: public ::testing::Test,
public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
INSTANTIATE_TEST_SUITE_P(MultiChannel,
ResidualEchoEstimatorMultiChannel,
::testing::Combine(::testing::Values(1, 2, 4),
::testing::Values(1, 2, 4)));
TEST_P(ResidualEchoEstimatorMultiChannel, BasicTest) {
const size_t num_render_channels = std::get<0>(GetParam());
const size_t num_capture_channels = std::get<1>(GetParam());
constexpr int kSampleRateHz = 48000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
@ -30,21 +40,17 @@ TEST(ResidualEchoEstimator, BasicTest) {
ResidualEchoEstimator estimator(config, num_render_channels);
AecState aec_state(config, num_capture_channels);
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(config, kSampleRateHz,
num_render_channels));
RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels));
std::vector<std::array<float, kFftLengthBy2Plus1>> E2_main(
num_capture_channels);
std::vector<std::array<float, kFftLengthBy2Plus1>> S2_linear(
num_capture_channels);
std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(
num_capture_channels);
std::vector<std::array<float, kFftLengthBy2Plus1>> R2(
num_capture_channels);
std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels);
std::vector<std::array<float, kFftLengthBy2Plus1>> R2(num_capture_channels);
std::vector<std::vector<std::vector<float>>> x(
kNumBands,
std::vector<std::vector<float>>(num_render_channels,
std::vector<float>(kBlockSize, 0.f)));
kNumBands, std::vector<std::vector<float>>(
num_render_channels, std::vector<float>(kBlockSize, 0.f)));
std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> H2(
num_capture_channels,
std::vector<std::array<float, kFftLengthBy2Plus1>>(10));
@ -63,8 +69,8 @@ TEST(ResidualEchoEstimator, BasicTest) {
std::vector<std::vector<float>> h(
num_capture_channels,
std::vector<float>(
GetTimeDomainLength(config.filter.main.length_blocks), 0.f));
std::vector<float>(GetTimeDomainLength(config.filter.main.length_blocks),
0.f));
for (auto& subtractor_output : output) {
subtractor_output.Reset();
@ -97,7 +103,5 @@ TEST(ResidualEchoEstimator, BasicTest) {
S2_linear, Y2, R2);
}
}
}
}
} // namespace webrtc

View File

@ -27,7 +27,6 @@
namespace webrtc {
namespace {
// Method for performing the simulations needed to test the main filter update
// gain functionality.
void RunFilterUpdateTest(int num_blocks_to_process,
@ -153,12 +152,21 @@ TEST(ShadowFilterUpdateGain, NullDataOutputGain) {
#endif
class ShadowFilterUpdateGainOneTwoEightRenderChannels
: public ::testing::Test,
public ::testing::WithParamInterface<size_t> {};
INSTANTIATE_TEST_SUITE_P(MultiChannel,
ShadowFilterUpdateGainOneTwoEightRenderChannels,
::testing::Values(1, 2, 8));
// Verifies that the gain formed causes the filter using it to converge.
TEST(ShadowFilterUpdateGain, GainCausesFilterToConverge) {
TEST_P(ShadowFilterUpdateGainOneTwoEightRenderChannels,
GainCausesFilterToConverge) {
const size_t num_render_channels = GetParam();
std::vector<int> blocks_with_echo_path_changes;
std::vector<int> blocks_with_saturation;
for (size_t num_render_channels : {1, 2, 8}) {
for (size_t filter_length_blocks : {12, 20, 30}) {
for (size_t delay_samples : {0, 64, 150, 200, 301}) {
SCOPED_TRACE(ProduceDebugText(delay_samples, filter_length_blocks));
@ -168,14 +176,13 @@ TEST(ShadowFilterUpdateGain, GainCausesFilterToConverge) {
FftData G;
RunFilterUpdateTest(5000, delay_samples, num_render_channels,
filter_length_blocks, blocks_with_saturation, &e,
&y, &G);
filter_length_blocks, blocks_with_saturation, &e, &y,
&G);
// Verify that the main filter is able to perform well.
// Use different criteria to take overmodelling into account.
if (filter_length_blocks == 12) {
EXPECT_LT(
1000 * std::inner_product(e.begin(), e.end(), e.begin(), 0.f),
EXPECT_LT(1000 * std::inner_product(e.begin(), e.end(), e.begin(), 0.f),
std::inner_product(y.begin(), y.end(), y.begin(), 0.f));
} else {
EXPECT_LT(std::inner_product(e.begin(), e.end(), e.begin(), 0.f),
@ -184,12 +191,50 @@ TEST(ShadowFilterUpdateGain, GainCausesFilterToConverge) {
}
}
}
// Verifies that the gain is zero when there is saturation.
TEST_P(ShadowFilterUpdateGainOneTwoEightRenderChannels, SaturationBehavior) {
const size_t num_render_channels = GetParam();
std::vector<int> blocks_with_echo_path_changes;
std::vector<int> blocks_with_saturation;
for (int k = 99; k < 200; ++k) {
blocks_with_saturation.push_back(k);
}
for (size_t filter_length_blocks : {12, 20, 30}) {
SCOPED_TRACE(ProduceDebugText(filter_length_blocks));
std::array<float, kBlockSize> e;
std::array<float, kBlockSize> y;
FftData G_a;
FftData G_a_ref;
G_a_ref.re.fill(0.f);
G_a_ref.im.fill(0.f);
RunFilterUpdateTest(100, 65, num_render_channels, filter_length_blocks,
blocks_with_saturation, &e, &y, &G_a);
EXPECT_EQ(G_a_ref.re, G_a.re);
EXPECT_EQ(G_a_ref.im, G_a.im);
}
}
class ShadowFilterUpdateGainOneTwoFourRenderChannels
: public ::testing::Test,
public ::testing::WithParamInterface<size_t> {};
INSTANTIATE_TEST_SUITE_P(
MultiChannel,
ShadowFilterUpdateGainOneTwoFourRenderChannels,
::testing::Values(1, 2, 4),
[](const ::testing::TestParamInfo<
ShadowFilterUpdateGainOneTwoFourRenderChannels::ParamType>& info) {
return (rtc::StringBuilder() << "Render" << info.param).str();
});
// Verifies that the magnitude of the gain on average decreases for a
// persistently exciting signal.
TEST(ShadowFilterUpdateGain, DecreasingGain) {
for (size_t num_render_channels : {1, 2, 4}) {
TEST_P(ShadowFilterUpdateGainOneTwoFourRenderChannels, DecreasingGain) {
const size_t num_render_channels = GetParam();
for (size_t filter_length_blocks : {12, 20, 30}) {
SCOPED_TRACE(ProduceDebugText(filter_length_blocks));
std::vector<int> blocks_with_echo_path_changes;
@ -222,33 +267,4 @@ TEST(ShadowFilterUpdateGain, DecreasingGain) {
std::accumulate(G_c_power.begin(), G_c_power.end(), 0.));
}
}
}
// Verifies that the gain is zero when there is saturation.
TEST(ShadowFilterUpdateGain, SaturationBehavior) {
std::vector<int> blocks_with_echo_path_changes;
std::vector<int> blocks_with_saturation;
for (int k = 99; k < 200; ++k) {
blocks_with_saturation.push_back(k);
}
for (size_t num_render_channels : {1, 2, 8}) {
for (size_t filter_length_blocks : {12, 20, 30}) {
SCOPED_TRACE(ProduceDebugText(filter_length_blocks));
std::array<float, kBlockSize> e;
std::array<float, kBlockSize> y;
FftData G_a;
FftData G_a_ref;
G_a_ref.re.fill(0.f);
G_a_ref.im.fill(0.f);
RunFilterUpdateTest(100, 65, num_render_channels, filter_length_blocks,
blocks_with_saturation, &e, &y, &G_a);
EXPECT_EQ(G_a_ref.re, G_a.re);
EXPECT_EQ(G_a_ref.im, G_a.im);
}
}
}
} // namespace webrtc

View File

@ -138,13 +138,21 @@ void TestInputs::UpdateCurrentPowerSpectra() {
} // namespace
TEST(SignalDependentErleEstimator, SweepSettings) {
for (size_t num_render_channels : {1, 2, 4}) {
for (size_t num_capture_channels : {1, 2, 4}) {
class SignalDependentErleEstimatorMultiChannel
: public ::testing::Test,
public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
INSTANTIATE_TEST_SUITE_P(MultiChannel,
SignalDependentErleEstimatorMultiChannel,
::testing::Combine(::testing::Values(1, 2, 4),
::testing::Values(1, 2, 4)));
TEST_P(SignalDependentErleEstimatorMultiChannel, SweepSettings) {
const size_t num_render_channels = std::get<0>(GetParam());
const size_t num_capture_channels = std::get<1>(GetParam());
EchoCanceller3Config cfg;
size_t max_length_blocks = 50;
for (size_t blocks = 1; blocks < max_length_blocks;
blocks = blocks + 10) {
for (size_t blocks = 1; blocks < max_length_blocks; blocks = blocks + 10) {
for (size_t delay_headroom = 0; delay_headroom < 5; ++delay_headroom) {
for (size_t num_sections = 2; num_sections < max_length_blocks;
++num_sections) {
@ -163,11 +171,9 @@ TEST(SignalDependentErleEstimator, SweepSettings) {
TestInputs inputs(cfg, num_render_channels, num_capture_channels);
for (size_t n = 0; n < 10; ++n) {
inputs.Update();
s.Update(inputs.GetRenderBuffer(), inputs.GetH2(),
inputs.GetX2(), inputs.GetY2(), inputs.GetE2(),
average_erle, inputs.GetConvergedFilters());
}
}
s.Update(inputs.GetRenderBuffer(), inputs.GetH2(), inputs.GetX2(),
inputs.GetY2(), inputs.GetE2(), average_erle,
inputs.GetConvergedFilters());
}
}
}
@ -175,9 +181,9 @@ TEST(SignalDependentErleEstimator, SweepSettings) {
}
}
TEST(SignalDependentErleEstimator, LongerRun) {
for (size_t num_render_channels : {1, 2, 4}) {
for (size_t num_capture_channels : {1, 2, 4}) {
TEST_P(SignalDependentErleEstimatorMultiChannel, LongerRun) {
const size_t num_render_channels = std::get<0>(GetParam());
const size_t num_capture_channels = std::get<1>(GetParam());
EchoCanceller3Config cfg;
cfg.filter.main.length_blocks = 2;
cfg.filter.main_initial.length_blocks = 1;
@ -199,7 +205,5 @@ TEST(SignalDependentErleEstimator, LongerRun) {
inputs.GetConvergedFilters());
}
}
}
}
} // namespace webrtc

View File

@ -231,33 +231,6 @@ TEST(Subtractor, Convergence) {
}
}
// Verifies that the subtractor is able to converge on correlated data.
TEST(Subtractor, ConvergenceMultiChannel) {
#if defined(NDEBUG)
const size_t kNumRenderChannelsToTest[] = {1, 2, 8};
const size_t kNumCaptureChannelsToTest[] = {1, 2, 4};
#else
const size_t kNumRenderChannelsToTest[] = {1, 2};
const size_t kNumCaptureChannelsToTest[] = {1, 2};
#endif
std::vector<int> blocks_with_echo_path_changes;
for (size_t num_render_channels : kNumRenderChannelsToTest) {
for (size_t num_capture_channels : kNumCaptureChannelsToTest) {
SCOPED_TRACE(
ProduceDebugText(num_render_channels, num_render_channels, 64, 20));
size_t num_blocks_to_process = 2500 * num_render_channels;
std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
num_render_channels, num_capture_channels, num_blocks_to_process, 64,
20, 20, false, blocks_with_echo_path_changes);
for (float echo_to_nearend_power : echo_to_nearend_powers) {
EXPECT_GT(0.1f, echo_to_nearend_power);
}
}
}
}
// Verifies that the subtractor is able to handle the case when the main filter
// is longer than the shadow filter.
TEST(Subtractor, MainFilterLongerThanShadowFilter) {
@ -297,23 +270,68 @@ TEST(Subtractor, NonConvergenceOnUncorrelatedSignals) {
}
}
// Verifies that the subtractor does not converge on uncorrelated signals.
TEST(Subtractor, NonConvergenceOnUncorrelatedSignalsMultiChannel) {
class SubtractorMultiChannelUpToEightRender
: public ::testing::Test,
public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
#if defined(NDEBUG)
INSTANTIATE_TEST_SUITE_P(NonDebugMultiChannel,
SubtractorMultiChannelUpToEightRender,
::testing::Combine(::testing::Values(1, 2, 8),
::testing::Values(1, 2, 4)));
#else
INSTANTIATE_TEST_SUITE_P(DebugMultiChannel,
SubtractorMultiChannelUpToEightRender,
::testing::Combine(::testing::Values(1, 2),
::testing::Values(1, 2)));
#endif
// Verifies that the subtractor is able to converge on correlated data.
TEST_P(SubtractorMultiChannelUpToEightRender, Convergence) {
const size_t num_render_channels = std::get<0>(GetParam());
const size_t num_capture_channels = std::get<1>(GetParam());
std::vector<int> blocks_with_echo_path_changes;
size_t num_blocks_to_process = 2500 * num_render_channels;
std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
num_render_channels, num_capture_channels, num_blocks_to_process, 64, 20,
20, false, blocks_with_echo_path_changes);
for (float echo_to_nearend_power : echo_to_nearend_powers) {
EXPECT_GT(0.1f, echo_to_nearend_power);
}
}
class SubtractorMultiChannelUpToFourRender
: public ::testing::Test,
public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
#if defined(NDEBUG)
INSTANTIATE_TEST_SUITE_P(NonDebugMultiChannel,
SubtractorMultiChannelUpToFourRender,
::testing::Combine(::testing::Values(1, 2, 4),
::testing::Values(1, 2, 4)));
#else
INSTANTIATE_TEST_SUITE_P(DebugMultiChannel,
SubtractorMultiChannelUpToFourRender,
::testing::Combine(::testing::Values(1, 2),
::testing::Values(1, 2)));
#endif
// Verifies that the subtractor does not converge on uncorrelated signals.
TEST_P(SubtractorMultiChannelUpToFourRender,
NonConvergenceOnUncorrelatedSignals) {
const size_t num_render_channels = std::get<0>(GetParam());
const size_t num_capture_channels = std::get<1>(GetParam());
std::vector<int> blocks_with_echo_path_changes;
for (size_t num_render_channels : {1, 2, 4}) {
for (size_t num_capture_channels : {1, 2, 4}) {
SCOPED_TRACE(
ProduceDebugText(num_render_channels, num_render_channels, 64, 20));
size_t num_blocks_to_process = 5000 * num_render_channels;
std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
num_render_channels, num_capture_channels, num_blocks_to_process, 64,
20, 20, true, blocks_with_echo_path_changes);
num_render_channels, num_capture_channels, num_blocks_to_process, 64, 20,
20, true, blocks_with_echo_path_changes);
for (float echo_to_nearend_power : echo_to_nearend_powers) {
EXPECT_LT(.8f, echo_to_nearend_power);
EXPECT_NEAR(1.f, echo_to_nearend_power, 0.25f);
}
}
}
}
} // namespace webrtc