aec3: Support AVX2/FMA intrinsics in AEC3

Bug: webrtc:11663
Change-Id: Ib75eb616ef0cb62698b0d96af7ebe42e93825222
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/179006
Reviewed-by: Mirko Bonadei <mbonadei@webrtc.org>
Reviewed-by: Sam Zackrisson <saza@webrtc.org>
Reviewed-by: Per Åhgren <peah@webrtc.org>
Commit-Queue: Mirko Bonadei <mbonadei@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32023}
This commit is contained in:
Zhaoliang Ma 2020-08-31 10:20:47 +08:00 committed by Commit Bot
parent 090049c546
commit e537e9ca13
21 changed files with 945 additions and 19 deletions

View File

@ -13,13 +13,9 @@ rtc_library("aec3") {
configs += [ "..:apm_debug_dump" ]
sources = [
"adaptive_fir_filter.cc",
"adaptive_fir_filter.h",
"adaptive_fir_filter_erl.cc",
"adaptive_fir_filter_erl.h",
"aec3_common.cc",
"aec3_common.h",
"aec3_fft.cc",
"aec3_fft.h",
"aec_state.cc",
"aec_state.h",
"alignment_mixer.cc",
@ -27,7 +23,6 @@ rtc_library("aec3") {
"api_call_jitter_metrics.cc",
"api_call_jitter_metrics.h",
"block_buffer.cc",
"block_buffer.h",
"block_delay_buffer.cc",
"block_delay_buffer.h",
"block_framer.cc",
@ -66,8 +61,6 @@ rtc_library("aec3") {
"erle_estimator.cc",
"erle_estimator.h",
"fft_buffer.cc",
"fft_buffer.h",
"fft_data.h",
"filter_analyzer.cc",
"filter_analyzer.h",
"frame_blocker.cc",
@ -75,7 +68,6 @@ rtc_library("aec3") {
"fullband_erle_estimator.cc",
"fullband_erle_estimator.h",
"matched_filter.cc",
"matched_filter.h",
"matched_filter_lag_aggregator.cc",
"matched_filter_lag_aggregator.h",
"moving_average.cc",
@ -84,7 +76,6 @@ rtc_library("aec3") {
"refined_filter_update_gain.cc",
"refined_filter_update_gain.h",
"render_buffer.cc",
"render_buffer.h",
"render_delay_buffer.cc",
"render_delay_buffer.h",
"render_delay_controller.cc",
@ -106,7 +97,6 @@ rtc_library("aec3") {
"signal_dependent_erle_estimator.cc",
"signal_dependent_erle_estimator.h",
"spectrum_buffer.cc",
"spectrum_buffer.h",
"stationarity_estimator.cc",
"stationarity_estimator.h",
"subband_erle_estimator.cc",
@ -123,7 +113,6 @@ rtc_library("aec3") {
"suppression_filter.h",
"suppression_gain.cc",
"suppression_gain.h",
"vector_math.h",
]
defines = []
@ -133,6 +122,14 @@ rtc_library("aec3") {
}
deps = [
":adaptive_fir_filter",
":adaptive_fir_filter_erl",
":aec3_common",
":aec3_fft",
":fft_data",
":matched_filter",
":render_buffer",
":vector_math",
"..:apm_logging",
"..:audio_buffer",
"..:high_pass_filter",
@ -140,7 +137,6 @@ rtc_library("aec3") {
"../../../api/audio:aec3_config",
"../../../api/audio:echo_control",
"../../../common_audio:common_audio_c",
"../../../common_audio/third_party/ooura:fft_size_128",
"../../../rtc_base:checks",
"../../../rtc_base:rtc_base_approved",
"../../../rtc_base:safe_minmax",
@ -152,6 +148,127 @@ rtc_library("aec3") {
"../utility:cascaded_biquad_filter",
]
absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ]
if (current_cpu == "x86" || current_cpu == "x64") {
deps += [ ":aec3_avx2" ]
}
}
rtc_source_set("aec3_common") {
sources = [ "aec3_common.h" ]
}
rtc_source_set("aec3_fft") {
sources = [ "aec3_fft.h" ]
deps = [
":aec3_common",
":fft_data",
"../../../api:array_view",
"../../../common_audio/third_party/ooura:fft_size_128",
"../../../rtc_base:checks",
"../../../rtc_base:rtc_base_approved",
"../../../rtc_base/system:arch",
]
}
rtc_source_set("render_buffer") {
sources = [
"block_buffer.h",
"fft_buffer.h",
"render_buffer.h",
"spectrum_buffer.h",
]
deps = [
":aec3_common",
":fft_data",
"../../../api:array_view",
"../../../rtc_base:checks",
"../../../rtc_base:rtc_base_approved",
"../../../rtc_base/system:arch",
]
}
rtc_source_set("adaptive_fir_filter") {
sources = [ "adaptive_fir_filter.h" ]
deps = [
":aec3_common",
":aec3_fft",
":fft_data",
":render_buffer",
"..:apm_logging",
"../../../api:array_view",
"../../../rtc_base/system:arch",
]
}
rtc_source_set("adaptive_fir_filter_erl") {
sources = [ "adaptive_fir_filter_erl.h" ]
deps = [
":aec3_common",
"../../../api:array_view",
"../../../rtc_base/system:arch",
]
}
rtc_source_set("matched_filter") {
sources = [ "matched_filter.h" ]
deps = [
":aec3_common",
"../../../api:array_view",
"../../../rtc_base:rtc_base_approved",
"../../../rtc_base/system:arch",
]
}
rtc_source_set("vector_math") {
sources = [ "vector_math.h" ]
deps = [
":aec3_common",
"../../../api:array_view",
"../../../rtc_base:checks",
"../../../rtc_base/system:arch",
]
}
rtc_source_set("fft_data") {
sources = [ "fft_data.h" ]
deps = [
":aec3_common",
"../../../api:array_view",
"../../../rtc_base/system:arch",
]
}
if (current_cpu == "x86" || current_cpu == "x64") {
rtc_library("aec3_avx2") {
configs += [ "..:apm_debug_dump" ]
sources = [
"adaptive_fir_filter_avx2.cc",
"adaptive_fir_filter_erl_avx2.cc",
"fft_data_avx2.cc",
"matched_filter_avx2.cc",
"vector_math_avx2.cc",
]
if (is_win) {
cflags = [ "/arch:AVX2" ]
} else {
cflags = [
"-mavx2",
"-mfma",
]
}
deps = [
":adaptive_fir_filter",
":adaptive_fir_filter_erl",
":fft_data",
":matched_filter",
":vector_math",
"../../../api:array_view",
"../../../rtc_base:checks",
]
}
}
if (rtc_include_tests) {
@ -171,7 +288,15 @@ if (rtc_include_tests) {
]
deps = [
":adaptive_fir_filter",
":adaptive_fir_filter_erl",
":aec3",
":aec3_common",
":aec3_fft",
":fft_data",
":matched_filter",
":render_buffer",
":vector_math",
"..:apm_logging",
"..:audio_buffer",
"..:audio_processing",

View File

@ -556,6 +556,9 @@ void AdaptiveFirFilter::Filter(const RenderBuffer& render_buffer,
case Aec3Optimization::kSse2:
aec3::ApplyFilter_Sse2(render_buffer, current_size_partitions_, H_, S);
break;
case Aec3Optimization::kAvx2:
aec3::ApplyFilter_Avx2(render_buffer, current_size_partitions_, H_, S);
break;
#endif
#if defined(WEBRTC_HAS_NEON)
case Aec3Optimization::kNeon:
@ -597,6 +600,9 @@ void AdaptiveFirFilter::ComputeFrequencyResponse(
case Aec3Optimization::kSse2:
aec3::ComputeFrequencyResponse_Sse2(current_size_partitions_, H_, H2);
break;
case Aec3Optimization::kAvx2:
aec3::ComputeFrequencyResponse_Avx2(current_size_partitions_, H_, H2);
break;
#endif
#if defined(WEBRTC_HAS_NEON)
case Aec3Optimization::kNeon:
@ -620,6 +626,10 @@ void AdaptiveFirFilter::AdaptAndUpdateSize(const RenderBuffer& render_buffer,
aec3::AdaptPartitions_Sse2(render_buffer, G, current_size_partitions_,
&H_);
break;
case Aec3Optimization::kAvx2:
aec3::AdaptPartitions_Avx2(render_buffer, G, current_size_partitions_,
&H_);
break;
#endif
#if defined(WEBRTC_HAS_NEON)
case Aec3Optimization::kNeon:

View File

@ -42,6 +42,11 @@ void ComputeFrequencyResponse_Sse2(
size_t num_partitions,
const std::vector<std::vector<FftData>>& H,
std::vector<std::array<float, kFftLengthBy2Plus1>>* H2);
void ComputeFrequencyResponse_Avx2(
size_t num_partitions,
const std::vector<std::vector<FftData>>& H,
std::vector<std::array<float, kFftLengthBy2Plus1>>* H2);
#endif
// Adapts the filter partitions.
@ -60,6 +65,11 @@ void AdaptPartitions_Sse2(const RenderBuffer& render_buffer,
const FftData& G,
size_t num_partitions,
std::vector<std::vector<FftData>>* H);
void AdaptPartitions_Avx2(const RenderBuffer& render_buffer,
const FftData& G,
size_t num_partitions,
std::vector<std::vector<FftData>>* H);
#endif
// Produces the filter output.
@ -78,6 +88,11 @@ void ApplyFilter_Sse2(const RenderBuffer& render_buffer,
size_t num_partitions,
const std::vector<std::vector<FftData>>& H,
FftData* S);
void ApplyFilter_Avx2(const RenderBuffer& render_buffer,
size_t num_partitions,
const std::vector<std::vector<FftData>>& H,
FftData* S);
#endif
} // namespace aec3

View File

@ -0,0 +1,187 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/aec3/adaptive_fir_filter.h"
#include <immintrin.h>
#include "rtc_base/checks.h"
namespace webrtc {
namespace aec3 {
// Computes and stores the frequency response of the filter.
void ComputeFrequencyResponse_Avx2(
size_t num_partitions,
const std::vector<std::vector<FftData>>& H,
std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) {
for (auto& H2_ch : *H2) {
H2_ch.fill(0.f);
}
const size_t num_render_channels = H[0].size();
RTC_DCHECK_EQ(H.size(), H2->capacity());
for (size_t p = 0; p < num_partitions; ++p) {
RTC_DCHECK_EQ(kFftLengthBy2Plus1, (*H2)[p].size());
for (size_t ch = 0; ch < num_render_channels; ++ch) {
for (size_t j = 0; j < kFftLengthBy2; j += 8) {
__m256 re = _mm256_loadu_ps(&H[p][ch].re[j]);
__m256 re2 = _mm256_mul_ps(re, re);
__m256 im = _mm256_loadu_ps(&H[p][ch].im[j]);
re2 = _mm256_fmadd_ps(im, im, re2);
__m256 H2_k_j = _mm256_loadu_ps(&(*H2)[p][j]);
H2_k_j = _mm256_max_ps(H2_k_j, re2);
_mm256_storeu_ps(&(*H2)[p][j], H2_k_j);
}
float H2_new = H[p][ch].re[kFftLengthBy2] * H[p][ch].re[kFftLengthBy2] +
H[p][ch].im[kFftLengthBy2] * H[p][ch].im[kFftLengthBy2];
(*H2)[p][kFftLengthBy2] = std::max((*H2)[p][kFftLengthBy2], H2_new);
}
}
}
// Adapts the filter partitions.
void AdaptPartitions_Avx2(const RenderBuffer& render_buffer,
const FftData& G,
size_t num_partitions,
std::vector<std::vector<FftData>>* H) {
rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
render_buffer.GetFftBuffer();
const size_t num_render_channels = render_buffer_data[0].size();
const size_t lim1 = std::min(
render_buffer_data.size() - render_buffer.Position(), num_partitions);
const size_t lim2 = num_partitions;
constexpr size_t kNumEightBinBands = kFftLengthBy2 / 8;
size_t X_partition = render_buffer.Position();
size_t limit = lim1;
size_t p = 0;
do {
for (; p < limit; ++p, ++X_partition) {
for (size_t ch = 0; ch < num_render_channels; ++ch) {
FftData& H_p_ch = (*H)[p][ch];
const FftData& X = render_buffer_data[X_partition][ch];
for (size_t k = 0, n = 0; n < kNumEightBinBands; ++n, k += 8) {
const __m256 G_re = _mm256_loadu_ps(&G.re[k]);
const __m256 G_im = _mm256_loadu_ps(&G.im[k]);
const __m256 X_re = _mm256_loadu_ps(&X.re[k]);
const __m256 X_im = _mm256_loadu_ps(&X.im[k]);
const __m256 H_re = _mm256_loadu_ps(&H_p_ch.re[k]);
const __m256 H_im = _mm256_loadu_ps(&H_p_ch.im[k]);
const __m256 a = _mm256_mul_ps(X_re, G_re);
const __m256 b = _mm256_mul_ps(X_im, G_im);
const __m256 c = _mm256_mul_ps(X_re, G_im);
const __m256 d = _mm256_mul_ps(X_im, G_re);
const __m256 e = _mm256_add_ps(a, b);
const __m256 f = _mm256_sub_ps(c, d);
const __m256 g = _mm256_add_ps(H_re, e);
const __m256 h = _mm256_add_ps(H_im, f);
_mm256_storeu_ps(&H_p_ch.re[k], g);
_mm256_storeu_ps(&H_p_ch.im[k], h);
}
}
}
X_partition = 0;
limit = lim2;
} while (p < lim2);
X_partition = render_buffer.Position();
limit = lim1;
p = 0;
do {
for (; p < limit; ++p, ++X_partition) {
for (size_t ch = 0; ch < num_render_channels; ++ch) {
FftData& H_p_ch = (*H)[p][ch];
const FftData& X = render_buffer_data[X_partition][ch];
H_p_ch.re[kFftLengthBy2] += X.re[kFftLengthBy2] * G.re[kFftLengthBy2] +
X.im[kFftLengthBy2] * G.im[kFftLengthBy2];
H_p_ch.im[kFftLengthBy2] += X.re[kFftLengthBy2] * G.im[kFftLengthBy2] -
X.im[kFftLengthBy2] * G.re[kFftLengthBy2];
}
}
X_partition = 0;
limit = lim2;
} while (p < lim2);
}
// Produces the filter output (AVX2 variant).
void ApplyFilter_Avx2(const RenderBuffer& render_buffer,
size_t num_partitions,
const std::vector<std::vector<FftData>>& H,
FftData* S) {
RTC_DCHECK_GE(H.size(), H.size() - 1);
S->re.fill(0.f);
S->im.fill(0.f);
rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
render_buffer.GetFftBuffer();
const size_t num_render_channels = render_buffer_data[0].size();
const size_t lim1 = std::min(
render_buffer_data.size() - render_buffer.Position(), num_partitions);
const size_t lim2 = num_partitions;
constexpr size_t kNumEightBinBands = kFftLengthBy2 / 8;
size_t X_partition = render_buffer.Position();
size_t p = 0;
size_t limit = lim1;
do {
for (; p < limit; ++p, ++X_partition) {
for (size_t ch = 0; ch < num_render_channels; ++ch) {
const FftData& H_p_ch = H[p][ch];
const FftData& X = render_buffer_data[X_partition][ch];
for (size_t k = 0, n = 0; n < kNumEightBinBands; ++n, k += 8) {
const __m256 X_re = _mm256_loadu_ps(&X.re[k]);
const __m256 X_im = _mm256_loadu_ps(&X.im[k]);
const __m256 H_re = _mm256_loadu_ps(&H_p_ch.re[k]);
const __m256 H_im = _mm256_loadu_ps(&H_p_ch.im[k]);
const __m256 S_re = _mm256_loadu_ps(&S->re[k]);
const __m256 S_im = _mm256_loadu_ps(&S->im[k]);
const __m256 a = _mm256_mul_ps(X_re, H_re);
const __m256 b = _mm256_mul_ps(X_im, H_im);
const __m256 c = _mm256_mul_ps(X_re, H_im);
const __m256 d = _mm256_mul_ps(X_im, H_re);
const __m256 e = _mm256_sub_ps(a, b);
const __m256 f = _mm256_add_ps(c, d);
const __m256 g = _mm256_add_ps(S_re, e);
const __m256 h = _mm256_add_ps(S_im, f);
_mm256_storeu_ps(&S->re[k], g);
_mm256_storeu_ps(&S->im[k], h);
}
}
}
limit = lim2;
X_partition = 0;
} while (p < lim2);
X_partition = render_buffer.Position();
p = 0;
limit = lim1;
do {
for (; p < limit; ++p, ++X_partition) {
for (size_t ch = 0; ch < num_render_channels; ++ch) {
const FftData& H_p_ch = H[p][ch];
const FftData& X = render_buffer_data[X_partition][ch];
S->re[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2] -
X.im[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2];
S->im[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2] +
X.im[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2];
}
}
limit = lim2;
X_partition = 0;
} while (p < lim2);
}
} // namespace aec3
} // namespace webrtc

View File

@ -85,10 +85,12 @@ void ComputeErl(const Aec3Optimization& optimization,
case Aec3Optimization::kSse2:
aec3::ErlComputer_SSE2(H2, erl);
break;
case Aec3Optimization::kAvx2:
aec3::ErlComputer_AVX2(H2, erl);
break;
#endif
#if defined(WEBRTC_HAS_NEON)
case Aec3Optimization::kNeon:
aec3::ErlComputer_NEON(H2, erl);
break;
#endif

View File

@ -36,6 +36,10 @@ void ErlComputer_NEON(
void ErlComputer_SSE2(
const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
rtc::ArrayView<float> erl);
void ErlComputer_AVX2(
const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
rtc::ArrayView<float> erl);
#endif
} // namespace aec3

View File

@ -0,0 +1,37 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h"
#include <immintrin.h>
namespace webrtc {
namespace aec3 {
// Computes and stores the echo return loss estimate of the filter, which is the
// sum of the partition frequency responses.
void ErlComputer_AVX2(
const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
rtc::ArrayView<float> erl) {
std::fill(erl.begin(), erl.end(), 0.f);
for (auto& H2_j : H2) {
for (size_t k = 0; k < kFftLengthBy2; k += 8) {
const __m256 H2_j_k = _mm256_loadu_ps(&H2_j[k]);
__m256 erl_k = _mm256_loadu_ps(&erl[k]);
erl_k = _mm256_add_ps(erl_k, H2_j_k);
_mm256_storeu_ps(&erl[k], erl_k);
}
erl[kFftLengthBy2] += H2_j[kFftLengthBy2];
}
}
} // namespace aec3
} // namespace webrtc

View File

@ -75,6 +75,31 @@ TEST(AdaptiveFirFilter, UpdateErlSse2Optimization) {
}
}
// Verifies that the optimized method for echo return loss computation is
// bitexact to the reference counterpart.
TEST(AdaptiveFirFilter, UpdateErlAvx2Optimization) {
bool use_avx2 = (WebRtc_GetCPUInfo(kAVX2) != 0);
if (use_avx2) {
const size_t kNumPartitions = 12;
std::vector<std::array<float, kFftLengthBy2Plus1>> H2(kNumPartitions);
std::array<float, kFftLengthBy2Plus1> erl;
std::array<float, kFftLengthBy2Plus1> erl_AVX2;
for (size_t j = 0; j < H2.size(); ++j) {
for (size_t k = 0; k < H2[j].size(); ++k) {
H2[j][k] = k + j / 3.f;
}
}
ErlComputer(H2, erl);
ErlComputer_AVX2(H2, erl_AVX2);
for (size_t j = 0; j < erl.size(); ++j) {
EXPECT_FLOAT_EQ(erl[j], erl_AVX2[j]);
}
}
}
#endif
} // namespace aec3

View File

@ -246,6 +246,81 @@ TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels,
}
}
// Verifies that the optimized methods for filter adaptation are bitexact to
// their reference counterparts.
TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels,
FilterAdaptationAvx2Optimizations) {
const size_t num_render_channels = GetParam();
constexpr int kSampleRateHz = 48000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
bool use_avx2 = (WebRtc_GetCPUInfo(kAVX2) != 0);
if (use_avx2) {
for (size_t num_partitions : {2, 5, 12, 30, 50}) {
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz,
num_render_channels));
Random random_generator(42U);
std::vector<std::vector<std::vector<float>>> x(
kNumBands,
std::vector<std::vector<float>>(num_render_channels,
std::vector<float>(kBlockSize, 0.f)));
FftData S_C;
FftData S_Avx2;
FftData G;
Aec3Fft fft;
std::vector<std::vector<FftData>> H_C(
num_partitions, std::vector<FftData>(num_render_channels));
std::vector<std::vector<FftData>> H_Avx2(
num_partitions, std::vector<FftData>(num_render_channels));
for (size_t p = 0; p < num_partitions; ++p) {
for (size_t ch = 0; ch < num_render_channels; ++ch) {
H_C[p][ch].Clear();
H_Avx2[p][ch].Clear();
}
}
for (size_t k = 0; k < 500; ++k) {
for (size_t band = 0; band < x.size(); ++band) {
for (size_t ch = 0; ch < x[band].size(); ++ch) {
RandomizeSampleVector(&random_generator, x[band][ch]);
}
}
render_delay_buffer->Insert(x);
if (k == 0) {
render_delay_buffer->Reset();
}
render_delay_buffer->PrepareCaptureProcessing();
auto* const render_buffer = render_delay_buffer->GetRenderBuffer();
ApplyFilter_Avx2(*render_buffer, num_partitions, H_Avx2, &S_Avx2);
ApplyFilter(*render_buffer, num_partitions, H_C, &S_C);
for (size_t j = 0; j < S_C.re.size(); ++j) {
EXPECT_FLOAT_EQ(S_C.re[j], S_Avx2.re[j]);
EXPECT_FLOAT_EQ(S_C.im[j], S_Avx2.im[j]);
}
std::for_each(G.re.begin(), G.re.end(),
[&](float& a) { a = random_generator.Rand<float>(); });
std::for_each(G.im.begin(), G.im.end(),
[&](float& a) { a = random_generator.Rand<float>(); });
AdaptPartitions_Avx2(*render_buffer, G, num_partitions, &H_Avx2);
AdaptPartitions(*render_buffer, G, num_partitions, &H_C);
for (size_t p = 0; p < num_partitions; ++p) {
for (size_t ch = 0; ch < num_render_channels; ++ch) {
for (size_t j = 0; j < H_C[p][ch].re.size(); ++j) {
EXPECT_FLOAT_EQ(H_C[p][ch].re[j], H_Avx2[p][ch].re[j]);
EXPECT_FLOAT_EQ(H_C[p][ch].im[j], H_Avx2[p][ch].im[j]);
}
}
}
}
}
}
}
// Verifies that the optimized method for frequency response computation is
// bitexact to the reference counterpart.
TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels,
@ -281,6 +356,41 @@ TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels,
}
}
// Verifies that the optimized method for frequency response computation is
// bitexact to the reference counterpart.
TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels,
ComputeFrequencyResponseAvx2Optimization) {
const size_t num_render_channels = GetParam();
bool use_avx2 = (WebRtc_GetCPUInfo(kAVX2) != 0);
if (use_avx2) {
for (size_t num_partitions : {2, 5, 12, 30, 50}) {
std::vector<std::vector<FftData>> H(
num_partitions, std::vector<FftData>(num_render_channels));
std::vector<std::array<float, kFftLengthBy2Plus1>> H2(num_partitions);
std::vector<std::array<float, kFftLengthBy2Plus1>> H2_Avx2(
num_partitions);
for (size_t p = 0; p < num_partitions; ++p) {
for (size_t ch = 0; ch < num_render_channels; ++ch) {
for (size_t k = 0; k < H[p][ch].re.size(); ++k) {
H[p][ch].re[k] = k + p / 3.f + ch;
H[p][ch].im[k] = p + k / 7.f - ch;
}
}
}
ComputeFrequencyResponse(num_partitions, H, &H2);
ComputeFrequencyResponse_Avx2(num_partitions, H, &H2_Avx2);
for (size_t p = 0; p < num_partitions; ++p) {
for (size_t k = 0; k < H2[p].size(); ++k) {
EXPECT_FLOAT_EQ(H2[p][k], H2_Avx2[p][k]);
}
}
}
}
}
#endif
#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)

View File

@ -20,7 +20,9 @@ namespace webrtc {
Aec3Optimization DetectOptimization() {
#if defined(WEBRTC_ARCH_X86_FAMILY)
if (WebRtc_GetCPUInfo(kSSE2) != 0) {
if (WebRtc_GetCPUInfo(kAVX2) != 0) {
return Aec3Optimization::kAvx2;
} else if (WebRtc_GetCPUInfo(kSSE2) != 0) {
return Aec3Optimization::kSse2;
}
#endif

View File

@ -23,7 +23,7 @@ namespace webrtc {
#define ALIGN16_END __attribute__((aligned(16)))
#endif
enum class Aec3Optimization { kNone, kSse2, kNeon };
enum class Aec3Optimization { kNone, kSse2, kAvx2, kNeon };
constexpr int kNumBlocksPerSecond = 250;

View File

@ -40,6 +40,9 @@ struct FftData {
im.fill(0.f);
}
// Computes the power spectrum of the data.
void SpectrumAVX2(rtc::ArrayView<float> power_spectrum) const;
// Computes the power spectrum of the data.
void Spectrum(Aec3Optimization optimization,
rtc::ArrayView<float> power_spectrum) const {
@ -60,6 +63,9 @@ struct FftData {
power_spectrum[kFftLengthBy2] = re[kFftLengthBy2] * re[kFftLengthBy2] +
im[kFftLengthBy2] * im[kFftLengthBy2];
} break;
case Aec3Optimization::kAvx2:
SpectrumAVX2(power_spectrum);
break;
#endif
default:
std::transform(re.begin(), re.end(), im.begin(), power_spectrum.begin(),

View File

@ -0,0 +1,33 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/aec3/fft_data.h"
#include <immintrin.h>
#include "api/array_view.h"
namespace webrtc {
// Computes the power spectrum of the data.
void FftData::SpectrumAVX2(rtc::ArrayView<float> power_spectrum) const {
RTC_DCHECK_EQ(kFftLengthBy2Plus1, power_spectrum.size());
for (size_t k = 0; k < kFftLengthBy2; k += 8) {
__m256 r = _mm256_loadu_ps(&re[k]);
__m256 i = _mm256_loadu_ps(&im[k]);
__m256 ii = _mm256_mul_ps(i, i);
ii = _mm256_fmadd_ps(r, r, ii);
_mm256_storeu_ps(&power_spectrum[k], ii);
}
power_spectrum[kFftLengthBy2] = re[kFftLengthBy2] * re[kFftLengthBy2] +
im[kFftLengthBy2] * im[kFftLengthBy2];
}
} // namespace webrtc

View File

@ -19,7 +19,7 @@ namespace webrtc {
#if defined(WEBRTC_ARCH_X86_FAMILY)
// Verifies that the optimized methods are bitexact to their reference
// counterparts.
TEST(FftData, TestOptimizations) {
TEST(FftData, TestSse2Optimizations) {
if (WebRtc_GetCPUInfo(kSSE2) != 0) {
FftData x;
@ -39,6 +39,29 @@ TEST(FftData, TestOptimizations) {
EXPECT_EQ(spectrum, spectrum_sse2);
}
}
// Verifies that the optimized methods are bitexact to their reference
// counterparts.
TEST(FftData, TestAvx2Optimizations) {
if (WebRtc_GetCPUInfo(kAVX2) != 0) {
FftData x;
for (size_t k = 0; k < x.re.size(); ++k) {
x.re[k] = k + 1;
}
x.im[0] = x.im[x.im.size() - 1] = 0.f;
for (size_t k = 1; k < x.im.size() - 1; ++k) {
x.im[k] = 2.f * (k + 1);
}
std::array<float, kFftLengthBy2Plus1> spectrum;
std::array<float, kFftLengthBy2Plus1> spectrum_avx2;
x.Spectrum(Aec3Optimization::kNone, spectrum);
x.Spectrum(Aec3Optimization::kAvx2, spectrum_avx2);
EXPECT_EQ(spectrum, spectrum_avx2);
}
}
#endif
#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)

View File

@ -364,6 +364,11 @@ void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
smoothing_, render_buffer.buffer, y,
filters_[n], &filters_updated, &error_sum);
break;
case Aec3Optimization::kAvx2:
aec3::MatchedFilterCore_AVX2(x_start_index, x2_sum_threshold,
smoothing_, render_buffer.buffer, y,
filters_[n], &filters_updated, &error_sum);
break;
#endif
#if defined(WEBRTC_HAS_NEON)
case Aec3Optimization::kNeon:

View File

@ -53,6 +53,16 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
bool* filters_updated,
float* error_sum);
// Filter core for the matched filter that is optimized for AVX2.
void MatchedFilterCore_AVX2(size_t x_start_index,
float x2_sum_threshold,
float smoothing,
rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
float* error_sum);
#endif
// Filter core for the matched filter.

View File

@ -0,0 +1,132 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/aec3/matched_filter.h"
#include <immintrin.h>
#include "rtc_base/checks.h"
namespace webrtc {
namespace aec3 {
void MatchedFilterCore_AVX2(size_t x_start_index,
float x2_sum_threshold,
float smoothing,
rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
float* error_sum) {
const int h_size = static_cast<int>(h.size());
const int x_size = static_cast<int>(x.size());
RTC_DCHECK_EQ(0, h_size % 8);
// Process for all samples in the sub-block.
for (size_t i = 0; i < y.size(); ++i) {
// Apply the matched filter as filter * x, and compute x * x.
RTC_DCHECK_GT(x_size, x_start_index);
const float* x_p = &x[x_start_index];
const float* h_p = &h[0];
// Initialize values for the accumulation.
__m256 s_256 = _mm256_set1_ps(0);
__m256 x2_sum_256 = _mm256_set1_ps(0);
float x2_sum = 0.f;
float s = 0;
// Compute loop chunk sizes until, and after, the wraparound of the circular
// buffer for x.
const int chunk1 =
std::min(h_size, static_cast<int>(x_size - x_start_index));
// Perform the loop in two chunks.
const int chunk2 = h_size - chunk1;
for (int limit : {chunk1, chunk2}) {
// Perform 256 bit vector operations.
const int limit_by_8 = limit >> 3;
for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) {
// Load the data into 256 bit vectors.
__m256 x_k = _mm256_loadu_ps(x_p);
__m256 h_k = _mm256_loadu_ps(h_p);
// Compute and accumulate x * x and h * x.
x2_sum_256 = _mm256_fmadd_ps(x_k, x_k, x2_sum_256);
s_256 = _mm256_fmadd_ps(h_k, x_k, s_256);
}
// Perform non-vector operations for any remaining items.
for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) {
const float x_k = *x_p;
x2_sum += x_k * x_k;
s += *h_p * x_k;
}
x_p = &x[0];
}
// Sum components together.
__m128 x2_sum_128 = _mm_add_ps(_mm256_extractf128_ps(x2_sum_256, 0),
_mm256_extractf128_ps(x2_sum_256, 1));
__m128 s_128 = _mm_add_ps(_mm256_extractf128_ps(s_256, 0),
_mm256_extractf128_ps(s_256, 1));
// Combine the accumulated vector and scalar values.
float* v = reinterpret_cast<float*>(&x2_sum_128);
x2_sum += v[0] + v[1] + v[2] + v[3];
v = reinterpret_cast<float*>(&s_128);
s += v[0] + v[1] + v[2] + v[3];
// Compute the matched filter error.
float e = y[i] - s;
const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
(*error_sum) += e * e;
// Update the matched filter estimate in an NLMS manner.
if (x2_sum > x2_sum_threshold && !saturation) {
RTC_DCHECK_LT(0.f, x2_sum);
const float alpha = smoothing * e / x2_sum;
const __m256 alpha_256 = _mm256_set1_ps(alpha);
// filter = filter + smoothing * (y - filter * x) * x / x * x.
float* h_p = &h[0];
x_p = &x[x_start_index];
// Perform the loop in two chunks.
for (int limit : {chunk1, chunk2}) {
// Perform 256 bit vector operations.
const int limit_by_8 = limit >> 3;
for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) {
// Load the data into 256 bit vectors.
__m256 h_k = _mm256_loadu_ps(h_p);
__m256 x_k = _mm256_loadu_ps(x_p);
// Compute h = h + alpha * x.
h_k = _mm256_fmadd_ps(x_k, alpha_256, h_k);
// Store the result.
_mm256_storeu_ps(h_p, h_k);
}
// Perform non-vector operations for any remaining items.
for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) {
*h_p += alpha * *x_p;
}
x_p = &x[0];
}
*filters_updated = true;
}
x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
}
}
} // namespace aec3
} // namespace webrtc

View File

@ -133,6 +133,47 @@ TEST(MatchedFilter, TestSse2Optimizations) {
}
}
TEST(MatchedFilter, TestAvx2Optimizations) {
bool use_avx2 = (WebRtc_GetCPUInfo(kAVX2) != 0);
if (use_avx2) {
Random random_generator(42U);
constexpr float kSmoothing = 0.7f;
for (auto down_sampling_factor : kDownSamplingFactors) {
const size_t sub_block_size = kBlockSize / down_sampling_factor;
std::vector<float> x(2000);
RandomizeSampleVector(&random_generator, x);
std::vector<float> y(sub_block_size);
std::vector<float> h_AVX2(512);
std::vector<float> h(512);
int x_index = 0;
for (int k = 0; k < 1000; ++k) {
RandomizeSampleVector(&random_generator, y);
bool filters_updated = false;
float error_sum = 0.f;
bool filters_updated_AVX2 = false;
float error_sum_AVX2 = 0.f;
MatchedFilterCore_AVX2(x_index, h.size() * 150.f * 150.f, kSmoothing, x,
y, h_AVX2, &filters_updated_AVX2,
&error_sum_AVX2);
MatchedFilterCore(x_index, h.size() * 150.f * 150.f, kSmoothing, x, y,
h, &filters_updated, &error_sum);
EXPECT_EQ(filters_updated, filters_updated_AVX2);
EXPECT_NEAR(error_sum, error_sum_AVX2, error_sum / 100000.f);
for (size_t j = 0; j < h.size(); ++j) {
EXPECT_NEAR(h[j], h_AVX2[j], 0.00001f);
}
x_index = (x_index + sub_block_size) % x.size();
}
}
}
}
#endif
// Verifies that the matched filter produces proper lag estimates for

View File

@ -40,6 +40,7 @@ class VectorMath {
: optimization_(optimization) {}
// Elementwise square root.
void SqrtAVX2(rtc::ArrayView<float> x);
void Sqrt(rtc::ArrayView<float> x) {
switch (optimization_) {
#if defined(WEBRTC_ARCH_X86_FAMILY)
@ -58,6 +59,9 @@ class VectorMath {
x[j] = sqrtf(x[j]);
}
} break;
case Aec3Optimization::kAvx2:
SqrtAVX2(x);
break;
#endif
#if defined(WEBRTC_HAS_NEON)
case Aec3Optimization::kNeon: {
@ -110,6 +114,9 @@ class VectorMath {
}
// Elementwise vector multiplication z = x * y.
void MultiplyAVX2(rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> z);
void Multiply(rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> z) {
@ -133,6 +140,9 @@ class VectorMath {
z[j] = x[j] * y[j];
}
} break;
case Aec3Optimization::kAvx2:
MultiplyAVX2(x, y, z);
break;
#endif
#if defined(WEBRTC_HAS_NEON)
case Aec3Optimization::kNeon: {
@ -159,6 +169,7 @@ class VectorMath {
}
// Elementwise vector accumulation z += x.
void AccumulateAVX2(rtc::ArrayView<const float> x, rtc::ArrayView<float> z);
void Accumulate(rtc::ArrayView<const float> x, rtc::ArrayView<float> z) {
RTC_DCHECK_EQ(z.size(), x.size());
switch (optimization_) {
@ -179,6 +190,9 @@ class VectorMath {
z[j] += x[j];
}
} break;
case Aec3Optimization::kAvx2:
AccumulateAVX2(x, z);
break;
#endif
#if defined(WEBRTC_HAS_NEON)
case Aec3Optimization::kNeon: {

View File

@ -0,0 +1,82 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/aec3/vector_math.h"
#include <immintrin.h>
#include <math.h>
#include "api/array_view.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace aec3 {
// Elementwise square root.
void VectorMath::SqrtAVX2(rtc::ArrayView<float> x) {
const int x_size = static_cast<int>(x.size());
const int vector_limit = x_size >> 3;
int j = 0;
for (; j < vector_limit * 8; j += 8) {
__m256 g = _mm256_loadu_ps(&x[j]);
g = _mm256_sqrt_ps(g);
_mm256_storeu_ps(&x[j], g);
}
for (; j < x_size; ++j) {
x[j] = sqrtf(x[j]);
}
}
// Elementwise vector multiplication z = x * y.
void VectorMath::MultiplyAVX2(rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> z) {
RTC_DCHECK_EQ(z.size(), x.size());
RTC_DCHECK_EQ(z.size(), y.size());
const int x_size = static_cast<int>(x.size());
const int vector_limit = x_size >> 3;
int j = 0;
for (; j < vector_limit * 8; j += 8) {
const __m256 x_j = _mm256_loadu_ps(&x[j]);
const __m256 y_j = _mm256_loadu_ps(&y[j]);
const __m256 z_j = _mm256_mul_ps(x_j, y_j);
_mm256_storeu_ps(&z[j], z_j);
}
for (; j < x_size; ++j) {
z[j] = x[j] * y[j];
}
}
// Elementwise vector accumulation z += x.
void VectorMath::AccumulateAVX2(rtc::ArrayView<const float> x,
rtc::ArrayView<float> z) {
RTC_DCHECK_EQ(z.size(), x.size());
const int x_size = static_cast<int>(x.size());
const int vector_limit = x_size >> 3;
int j = 0;
for (; j < vector_limit * 8; j += 8) {
const __m256 x_j = _mm256_loadu_ps(&x[j]);
__m256 z_j = _mm256_loadu_ps(&z[j]);
z_j = _mm256_add_ps(x_j, z_j);
_mm256_storeu_ps(&z[j], z_j);
}
for (; j < x_size; ++j) {
z[j] += x[j];
}
}
} // namespace aec3
} // namespace webrtc

View File

@ -79,7 +79,7 @@ TEST(VectorMath, Accumulate) {
#if defined(WEBRTC_ARCH_X86_FAMILY)
TEST(VectorMath, Sqrt) {
TEST(VectorMath, Sse2Sqrt) {
if (WebRtc_GetCPUInfo(kSSE2) != 0) {
std::array<float, kFftLengthBy2Plus1> x;
std::array<float, kFftLengthBy2Plus1> z;
@ -101,7 +101,29 @@ TEST(VectorMath, Sqrt) {
}
}
TEST(VectorMath, Multiply) {
TEST(VectorMath, Avx2Sqrt) {
if (WebRtc_GetCPUInfo(kAVX2) != 0) {
std::array<float, kFftLengthBy2Plus1> x;
std::array<float, kFftLengthBy2Plus1> z;
std::array<float, kFftLengthBy2Plus1> z_avx2;
for (size_t k = 0; k < x.size(); ++k) {
x[k] = (2.f / 3.f) * k;
}
std::copy(x.begin(), x.end(), z.begin());
aec3::VectorMath(Aec3Optimization::kNone).Sqrt(z);
std::copy(x.begin(), x.end(), z_avx2.begin());
aec3::VectorMath(Aec3Optimization::kAvx2).Sqrt(z_avx2);
EXPECT_EQ(z, z_avx2);
for (size_t k = 0; k < z.size(); ++k) {
EXPECT_FLOAT_EQ(z[k], z_avx2[k]);
EXPECT_FLOAT_EQ(sqrtf(x[k]), z_avx2[k]);
}
}
}
TEST(VectorMath, Sse2Multiply) {
if (WebRtc_GetCPUInfo(kSSE2) != 0) {
std::array<float, kFftLengthBy2Plus1> x;
std::array<float, kFftLengthBy2Plus1> y;
@ -122,7 +144,28 @@ TEST(VectorMath, Multiply) {
}
}
TEST(VectorMath, Accumulate) {
TEST(VectorMath, Avx2Multiply) {
if (WebRtc_GetCPUInfo(kAVX2) != 0) {
std::array<float, kFftLengthBy2Plus1> x;
std::array<float, kFftLengthBy2Plus1> y;
std::array<float, kFftLengthBy2Plus1> z;
std::array<float, kFftLengthBy2Plus1> z_avx2;
for (size_t k = 0; k < x.size(); ++k) {
x[k] = k;
y[k] = (2.f / 3.f) * k;
}
aec3::VectorMath(Aec3Optimization::kNone).Multiply(x, y, z);
aec3::VectorMath(Aec3Optimization::kAvx2).Multiply(x, y, z_avx2);
for (size_t k = 0; k < z.size(); ++k) {
EXPECT_FLOAT_EQ(z[k], z_avx2[k]);
EXPECT_FLOAT_EQ(x[k] * y[k], z_avx2[k]);
}
}
}
TEST(VectorMath, Sse2Accumulate) {
if (WebRtc_GetCPUInfo(kSSE2) != 0) {
std::array<float, kFftLengthBy2Plus1> x;
std::array<float, kFftLengthBy2Plus1> z;
@ -141,6 +184,26 @@ TEST(VectorMath, Accumulate) {
}
}
}
TEST(VectorMath, Avx2Accumulate) {
if (WebRtc_GetCPUInfo(kAVX2) != 0) {
std::array<float, kFftLengthBy2Plus1> x;
std::array<float, kFftLengthBy2Plus1> z;
std::array<float, kFftLengthBy2Plus1> z_avx2;
for (size_t k = 0; k < x.size(); ++k) {
x[k] = k;
z[k] = z_avx2[k] = 2.f * k;
}
aec3::VectorMath(Aec3Optimization::kNone).Accumulate(x, z);
aec3::VectorMath(Aec3Optimization::kAvx2).Accumulate(x, z_avx2);
for (size_t k = 0; k < z.size(); ++k) {
EXPECT_FLOAT_EQ(z[k], z_avx2[k]);
EXPECT_FLOAT_EQ(x[k] + 2.f * x[k], z_avx2[k]);
}
}
}
#endif
} // namespace webrtc