aec3: Support AVX2/FMA intrinsics in AEC3
Bug: webrtc:11663 Change-Id: Ib75eb616ef0cb62698b0d96af7ebe42e93825222 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/179006 Reviewed-by: Mirko Bonadei <mbonadei@webrtc.org> Reviewed-by: Sam Zackrisson <saza@webrtc.org> Reviewed-by: Per Åhgren <peah@webrtc.org> Commit-Queue: Mirko Bonadei <mbonadei@webrtc.org> Cr-Commit-Position: refs/heads/master@{#32023}
This commit is contained in:
parent
090049c546
commit
e537e9ca13
@ -13,13 +13,9 @@ rtc_library("aec3") {
|
||||
configs += [ "..:apm_debug_dump" ]
|
||||
sources = [
|
||||
"adaptive_fir_filter.cc",
|
||||
"adaptive_fir_filter.h",
|
||||
"adaptive_fir_filter_erl.cc",
|
||||
"adaptive_fir_filter_erl.h",
|
||||
"aec3_common.cc",
|
||||
"aec3_common.h",
|
||||
"aec3_fft.cc",
|
||||
"aec3_fft.h",
|
||||
"aec_state.cc",
|
||||
"aec_state.h",
|
||||
"alignment_mixer.cc",
|
||||
@ -27,7 +23,6 @@ rtc_library("aec3") {
|
||||
"api_call_jitter_metrics.cc",
|
||||
"api_call_jitter_metrics.h",
|
||||
"block_buffer.cc",
|
||||
"block_buffer.h",
|
||||
"block_delay_buffer.cc",
|
||||
"block_delay_buffer.h",
|
||||
"block_framer.cc",
|
||||
@ -66,8 +61,6 @@ rtc_library("aec3") {
|
||||
"erle_estimator.cc",
|
||||
"erle_estimator.h",
|
||||
"fft_buffer.cc",
|
||||
"fft_buffer.h",
|
||||
"fft_data.h",
|
||||
"filter_analyzer.cc",
|
||||
"filter_analyzer.h",
|
||||
"frame_blocker.cc",
|
||||
@ -75,7 +68,6 @@ rtc_library("aec3") {
|
||||
"fullband_erle_estimator.cc",
|
||||
"fullband_erle_estimator.h",
|
||||
"matched_filter.cc",
|
||||
"matched_filter.h",
|
||||
"matched_filter_lag_aggregator.cc",
|
||||
"matched_filter_lag_aggregator.h",
|
||||
"moving_average.cc",
|
||||
@ -84,7 +76,6 @@ rtc_library("aec3") {
|
||||
"refined_filter_update_gain.cc",
|
||||
"refined_filter_update_gain.h",
|
||||
"render_buffer.cc",
|
||||
"render_buffer.h",
|
||||
"render_delay_buffer.cc",
|
||||
"render_delay_buffer.h",
|
||||
"render_delay_controller.cc",
|
||||
@ -106,7 +97,6 @@ rtc_library("aec3") {
|
||||
"signal_dependent_erle_estimator.cc",
|
||||
"signal_dependent_erle_estimator.h",
|
||||
"spectrum_buffer.cc",
|
||||
"spectrum_buffer.h",
|
||||
"stationarity_estimator.cc",
|
||||
"stationarity_estimator.h",
|
||||
"subband_erle_estimator.cc",
|
||||
@ -123,7 +113,6 @@ rtc_library("aec3") {
|
||||
"suppression_filter.h",
|
||||
"suppression_gain.cc",
|
||||
"suppression_gain.h",
|
||||
"vector_math.h",
|
||||
]
|
||||
|
||||
defines = []
|
||||
@ -133,6 +122,14 @@ rtc_library("aec3") {
|
||||
}
|
||||
|
||||
deps = [
|
||||
":adaptive_fir_filter",
|
||||
":adaptive_fir_filter_erl",
|
||||
":aec3_common",
|
||||
":aec3_fft",
|
||||
":fft_data",
|
||||
":matched_filter",
|
||||
":render_buffer",
|
||||
":vector_math",
|
||||
"..:apm_logging",
|
||||
"..:audio_buffer",
|
||||
"..:high_pass_filter",
|
||||
@ -140,7 +137,6 @@ rtc_library("aec3") {
|
||||
"../../../api/audio:aec3_config",
|
||||
"../../../api/audio:echo_control",
|
||||
"../../../common_audio:common_audio_c",
|
||||
"../../../common_audio/third_party/ooura:fft_size_128",
|
||||
"../../../rtc_base:checks",
|
||||
"../../../rtc_base:rtc_base_approved",
|
||||
"../../../rtc_base:safe_minmax",
|
||||
@ -152,6 +148,127 @@ rtc_library("aec3") {
|
||||
"../utility:cascaded_biquad_filter",
|
||||
]
|
||||
absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ]
|
||||
|
||||
if (current_cpu == "x86" || current_cpu == "x64") {
|
||||
deps += [ ":aec3_avx2" ]
|
||||
}
|
||||
}
|
||||
|
||||
rtc_source_set("aec3_common") {
|
||||
sources = [ "aec3_common.h" ]
|
||||
}
|
||||
|
||||
rtc_source_set("aec3_fft") {
|
||||
sources = [ "aec3_fft.h" ]
|
||||
deps = [
|
||||
":aec3_common",
|
||||
":fft_data",
|
||||
"../../../api:array_view",
|
||||
"../../../common_audio/third_party/ooura:fft_size_128",
|
||||
"../../../rtc_base:checks",
|
||||
"../../../rtc_base:rtc_base_approved",
|
||||
"../../../rtc_base/system:arch",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_source_set("render_buffer") {
|
||||
sources = [
|
||||
"block_buffer.h",
|
||||
"fft_buffer.h",
|
||||
"render_buffer.h",
|
||||
"spectrum_buffer.h",
|
||||
]
|
||||
deps = [
|
||||
":aec3_common",
|
||||
":fft_data",
|
||||
"../../../api:array_view",
|
||||
"../../../rtc_base:checks",
|
||||
"../../../rtc_base:rtc_base_approved",
|
||||
"../../../rtc_base/system:arch",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_source_set("adaptive_fir_filter") {
|
||||
sources = [ "adaptive_fir_filter.h" ]
|
||||
deps = [
|
||||
":aec3_common",
|
||||
":aec3_fft",
|
||||
":fft_data",
|
||||
":render_buffer",
|
||||
"..:apm_logging",
|
||||
"../../../api:array_view",
|
||||
"../../../rtc_base/system:arch",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_source_set("adaptive_fir_filter_erl") {
|
||||
sources = [ "adaptive_fir_filter_erl.h" ]
|
||||
deps = [
|
||||
":aec3_common",
|
||||
"../../../api:array_view",
|
||||
"../../../rtc_base/system:arch",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_source_set("matched_filter") {
|
||||
sources = [ "matched_filter.h" ]
|
||||
deps = [
|
||||
":aec3_common",
|
||||
"../../../api:array_view",
|
||||
"../../../rtc_base:rtc_base_approved",
|
||||
"../../../rtc_base/system:arch",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_source_set("vector_math") {
|
||||
sources = [ "vector_math.h" ]
|
||||
deps = [
|
||||
":aec3_common",
|
||||
"../../../api:array_view",
|
||||
"../../../rtc_base:checks",
|
||||
"../../../rtc_base/system:arch",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_source_set("fft_data") {
|
||||
sources = [ "fft_data.h" ]
|
||||
deps = [
|
||||
":aec3_common",
|
||||
"../../../api:array_view",
|
||||
"../../../rtc_base/system:arch",
|
||||
]
|
||||
}
|
||||
|
||||
if (current_cpu == "x86" || current_cpu == "x64") {
|
||||
rtc_library("aec3_avx2") {
|
||||
configs += [ "..:apm_debug_dump" ]
|
||||
sources = [
|
||||
"adaptive_fir_filter_avx2.cc",
|
||||
"adaptive_fir_filter_erl_avx2.cc",
|
||||
"fft_data_avx2.cc",
|
||||
"matched_filter_avx2.cc",
|
||||
"vector_math_avx2.cc",
|
||||
]
|
||||
|
||||
if (is_win) {
|
||||
cflags = [ "/arch:AVX2" ]
|
||||
} else {
|
||||
cflags = [
|
||||
"-mavx2",
|
||||
"-mfma",
|
||||
]
|
||||
}
|
||||
|
||||
deps = [
|
||||
":adaptive_fir_filter",
|
||||
":adaptive_fir_filter_erl",
|
||||
":fft_data",
|
||||
":matched_filter",
|
||||
":vector_math",
|
||||
"../../../api:array_view",
|
||||
"../../../rtc_base:checks",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
if (rtc_include_tests) {
|
||||
@ -171,7 +288,15 @@ if (rtc_include_tests) {
|
||||
]
|
||||
|
||||
deps = [
|
||||
":adaptive_fir_filter",
|
||||
":adaptive_fir_filter_erl",
|
||||
":aec3",
|
||||
":aec3_common",
|
||||
":aec3_fft",
|
||||
":fft_data",
|
||||
":matched_filter",
|
||||
":render_buffer",
|
||||
":vector_math",
|
||||
"..:apm_logging",
|
||||
"..:audio_buffer",
|
||||
"..:audio_processing",
|
||||
|
||||
@ -556,6 +556,9 @@ void AdaptiveFirFilter::Filter(const RenderBuffer& render_buffer,
|
||||
case Aec3Optimization::kSse2:
|
||||
aec3::ApplyFilter_Sse2(render_buffer, current_size_partitions_, H_, S);
|
||||
break;
|
||||
case Aec3Optimization::kAvx2:
|
||||
aec3::ApplyFilter_Avx2(render_buffer, current_size_partitions_, H_, S);
|
||||
break;
|
||||
#endif
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
case Aec3Optimization::kNeon:
|
||||
@ -597,6 +600,9 @@ void AdaptiveFirFilter::ComputeFrequencyResponse(
|
||||
case Aec3Optimization::kSse2:
|
||||
aec3::ComputeFrequencyResponse_Sse2(current_size_partitions_, H_, H2);
|
||||
break;
|
||||
case Aec3Optimization::kAvx2:
|
||||
aec3::ComputeFrequencyResponse_Avx2(current_size_partitions_, H_, H2);
|
||||
break;
|
||||
#endif
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
case Aec3Optimization::kNeon:
|
||||
@ -620,6 +626,10 @@ void AdaptiveFirFilter::AdaptAndUpdateSize(const RenderBuffer& render_buffer,
|
||||
aec3::AdaptPartitions_Sse2(render_buffer, G, current_size_partitions_,
|
||||
&H_);
|
||||
break;
|
||||
case Aec3Optimization::kAvx2:
|
||||
aec3::AdaptPartitions_Avx2(render_buffer, G, current_size_partitions_,
|
||||
&H_);
|
||||
break;
|
||||
#endif
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
case Aec3Optimization::kNeon:
|
||||
|
||||
@ -42,6 +42,11 @@ void ComputeFrequencyResponse_Sse2(
|
||||
size_t num_partitions,
|
||||
const std::vector<std::vector<FftData>>& H,
|
||||
std::vector<std::array<float, kFftLengthBy2Plus1>>* H2);
|
||||
|
||||
void ComputeFrequencyResponse_Avx2(
|
||||
size_t num_partitions,
|
||||
const std::vector<std::vector<FftData>>& H,
|
||||
std::vector<std::array<float, kFftLengthBy2Plus1>>* H2);
|
||||
#endif
|
||||
|
||||
// Adapts the filter partitions.
|
||||
@ -60,6 +65,11 @@ void AdaptPartitions_Sse2(const RenderBuffer& render_buffer,
|
||||
const FftData& G,
|
||||
size_t num_partitions,
|
||||
std::vector<std::vector<FftData>>* H);
|
||||
|
||||
void AdaptPartitions_Avx2(const RenderBuffer& render_buffer,
|
||||
const FftData& G,
|
||||
size_t num_partitions,
|
||||
std::vector<std::vector<FftData>>* H);
|
||||
#endif
|
||||
|
||||
// Produces the filter output.
|
||||
@ -78,6 +88,11 @@ void ApplyFilter_Sse2(const RenderBuffer& render_buffer,
|
||||
size_t num_partitions,
|
||||
const std::vector<std::vector<FftData>>& H,
|
||||
FftData* S);
|
||||
|
||||
void ApplyFilter_Avx2(const RenderBuffer& render_buffer,
|
||||
size_t num_partitions,
|
||||
const std::vector<std::vector<FftData>>& H,
|
||||
FftData* S);
|
||||
#endif
|
||||
|
||||
} // namespace aec3
|
||||
|
||||
187
modules/audio_processing/aec3/adaptive_fir_filter_avx2.cc
Normal file
187
modules/audio_processing/aec3/adaptive_fir_filter_avx2.cc
Normal file
@ -0,0 +1,187 @@
|
||||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/aec3/adaptive_fir_filter.h"
|
||||
|
||||
#include <immintrin.h>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
namespace aec3 {
|
||||
|
||||
// Computes and stores the frequency response of the filter.
|
||||
void ComputeFrequencyResponse_Avx2(
|
||||
size_t num_partitions,
|
||||
const std::vector<std::vector<FftData>>& H,
|
||||
std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) {
|
||||
for (auto& H2_ch : *H2) {
|
||||
H2_ch.fill(0.f);
|
||||
}
|
||||
|
||||
const size_t num_render_channels = H[0].size();
|
||||
RTC_DCHECK_EQ(H.size(), H2->capacity());
|
||||
for (size_t p = 0; p < num_partitions; ++p) {
|
||||
RTC_DCHECK_EQ(kFftLengthBy2Plus1, (*H2)[p].size());
|
||||
for (size_t ch = 0; ch < num_render_channels; ++ch) {
|
||||
for (size_t j = 0; j < kFftLengthBy2; j += 8) {
|
||||
__m256 re = _mm256_loadu_ps(&H[p][ch].re[j]);
|
||||
__m256 re2 = _mm256_mul_ps(re, re);
|
||||
__m256 im = _mm256_loadu_ps(&H[p][ch].im[j]);
|
||||
re2 = _mm256_fmadd_ps(im, im, re2);
|
||||
__m256 H2_k_j = _mm256_loadu_ps(&(*H2)[p][j]);
|
||||
H2_k_j = _mm256_max_ps(H2_k_j, re2);
|
||||
_mm256_storeu_ps(&(*H2)[p][j], H2_k_j);
|
||||
}
|
||||
float H2_new = H[p][ch].re[kFftLengthBy2] * H[p][ch].re[kFftLengthBy2] +
|
||||
H[p][ch].im[kFftLengthBy2] * H[p][ch].im[kFftLengthBy2];
|
||||
(*H2)[p][kFftLengthBy2] = std::max((*H2)[p][kFftLengthBy2], H2_new);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Adapts the filter partitions.
|
||||
void AdaptPartitions_Avx2(const RenderBuffer& render_buffer,
|
||||
const FftData& G,
|
||||
size_t num_partitions,
|
||||
std::vector<std::vector<FftData>>* H) {
|
||||
rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
|
||||
render_buffer.GetFftBuffer();
|
||||
const size_t num_render_channels = render_buffer_data[0].size();
|
||||
const size_t lim1 = std::min(
|
||||
render_buffer_data.size() - render_buffer.Position(), num_partitions);
|
||||
const size_t lim2 = num_partitions;
|
||||
constexpr size_t kNumEightBinBands = kFftLengthBy2 / 8;
|
||||
|
||||
size_t X_partition = render_buffer.Position();
|
||||
size_t limit = lim1;
|
||||
size_t p = 0;
|
||||
do {
|
||||
for (; p < limit; ++p, ++X_partition) {
|
||||
for (size_t ch = 0; ch < num_render_channels; ++ch) {
|
||||
FftData& H_p_ch = (*H)[p][ch];
|
||||
const FftData& X = render_buffer_data[X_partition][ch];
|
||||
|
||||
for (size_t k = 0, n = 0; n < kNumEightBinBands; ++n, k += 8) {
|
||||
const __m256 G_re = _mm256_loadu_ps(&G.re[k]);
|
||||
const __m256 G_im = _mm256_loadu_ps(&G.im[k]);
|
||||
const __m256 X_re = _mm256_loadu_ps(&X.re[k]);
|
||||
const __m256 X_im = _mm256_loadu_ps(&X.im[k]);
|
||||
const __m256 H_re = _mm256_loadu_ps(&H_p_ch.re[k]);
|
||||
const __m256 H_im = _mm256_loadu_ps(&H_p_ch.im[k]);
|
||||
const __m256 a = _mm256_mul_ps(X_re, G_re);
|
||||
const __m256 b = _mm256_mul_ps(X_im, G_im);
|
||||
const __m256 c = _mm256_mul_ps(X_re, G_im);
|
||||
const __m256 d = _mm256_mul_ps(X_im, G_re);
|
||||
const __m256 e = _mm256_add_ps(a, b);
|
||||
const __m256 f = _mm256_sub_ps(c, d);
|
||||
const __m256 g = _mm256_add_ps(H_re, e);
|
||||
const __m256 h = _mm256_add_ps(H_im, f);
|
||||
_mm256_storeu_ps(&H_p_ch.re[k], g);
|
||||
_mm256_storeu_ps(&H_p_ch.im[k], h);
|
||||
}
|
||||
}
|
||||
}
|
||||
X_partition = 0;
|
||||
limit = lim2;
|
||||
} while (p < lim2);
|
||||
|
||||
X_partition = render_buffer.Position();
|
||||
limit = lim1;
|
||||
p = 0;
|
||||
do {
|
||||
for (; p < limit; ++p, ++X_partition) {
|
||||
for (size_t ch = 0; ch < num_render_channels; ++ch) {
|
||||
FftData& H_p_ch = (*H)[p][ch];
|
||||
const FftData& X = render_buffer_data[X_partition][ch];
|
||||
|
||||
H_p_ch.re[kFftLengthBy2] += X.re[kFftLengthBy2] * G.re[kFftLengthBy2] +
|
||||
X.im[kFftLengthBy2] * G.im[kFftLengthBy2];
|
||||
H_p_ch.im[kFftLengthBy2] += X.re[kFftLengthBy2] * G.im[kFftLengthBy2] -
|
||||
X.im[kFftLengthBy2] * G.re[kFftLengthBy2];
|
||||
}
|
||||
}
|
||||
|
||||
X_partition = 0;
|
||||
limit = lim2;
|
||||
} while (p < lim2);
|
||||
}
|
||||
|
||||
// Produces the filter output (AVX2 variant).
|
||||
void ApplyFilter_Avx2(const RenderBuffer& render_buffer,
|
||||
size_t num_partitions,
|
||||
const std::vector<std::vector<FftData>>& H,
|
||||
FftData* S) {
|
||||
RTC_DCHECK_GE(H.size(), H.size() - 1);
|
||||
S->re.fill(0.f);
|
||||
S->im.fill(0.f);
|
||||
|
||||
rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
|
||||
render_buffer.GetFftBuffer();
|
||||
const size_t num_render_channels = render_buffer_data[0].size();
|
||||
const size_t lim1 = std::min(
|
||||
render_buffer_data.size() - render_buffer.Position(), num_partitions);
|
||||
const size_t lim2 = num_partitions;
|
||||
constexpr size_t kNumEightBinBands = kFftLengthBy2 / 8;
|
||||
|
||||
size_t X_partition = render_buffer.Position();
|
||||
size_t p = 0;
|
||||
size_t limit = lim1;
|
||||
do {
|
||||
for (; p < limit; ++p, ++X_partition) {
|
||||
for (size_t ch = 0; ch < num_render_channels; ++ch) {
|
||||
const FftData& H_p_ch = H[p][ch];
|
||||
const FftData& X = render_buffer_data[X_partition][ch];
|
||||
for (size_t k = 0, n = 0; n < kNumEightBinBands; ++n, k += 8) {
|
||||
const __m256 X_re = _mm256_loadu_ps(&X.re[k]);
|
||||
const __m256 X_im = _mm256_loadu_ps(&X.im[k]);
|
||||
const __m256 H_re = _mm256_loadu_ps(&H_p_ch.re[k]);
|
||||
const __m256 H_im = _mm256_loadu_ps(&H_p_ch.im[k]);
|
||||
const __m256 S_re = _mm256_loadu_ps(&S->re[k]);
|
||||
const __m256 S_im = _mm256_loadu_ps(&S->im[k]);
|
||||
const __m256 a = _mm256_mul_ps(X_re, H_re);
|
||||
const __m256 b = _mm256_mul_ps(X_im, H_im);
|
||||
const __m256 c = _mm256_mul_ps(X_re, H_im);
|
||||
const __m256 d = _mm256_mul_ps(X_im, H_re);
|
||||
const __m256 e = _mm256_sub_ps(a, b);
|
||||
const __m256 f = _mm256_add_ps(c, d);
|
||||
const __m256 g = _mm256_add_ps(S_re, e);
|
||||
const __m256 h = _mm256_add_ps(S_im, f);
|
||||
_mm256_storeu_ps(&S->re[k], g);
|
||||
_mm256_storeu_ps(&S->im[k], h);
|
||||
}
|
||||
}
|
||||
}
|
||||
limit = lim2;
|
||||
X_partition = 0;
|
||||
} while (p < lim2);
|
||||
|
||||
X_partition = render_buffer.Position();
|
||||
p = 0;
|
||||
limit = lim1;
|
||||
do {
|
||||
for (; p < limit; ++p, ++X_partition) {
|
||||
for (size_t ch = 0; ch < num_render_channels; ++ch) {
|
||||
const FftData& H_p_ch = H[p][ch];
|
||||
const FftData& X = render_buffer_data[X_partition][ch];
|
||||
S->re[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2] -
|
||||
X.im[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2];
|
||||
S->im[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2] +
|
||||
X.im[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2];
|
||||
}
|
||||
}
|
||||
limit = lim2;
|
||||
X_partition = 0;
|
||||
} while (p < lim2);
|
||||
}
|
||||
|
||||
} // namespace aec3
|
||||
} // namespace webrtc
|
||||
@ -85,10 +85,12 @@ void ComputeErl(const Aec3Optimization& optimization,
|
||||
case Aec3Optimization::kSse2:
|
||||
aec3::ErlComputer_SSE2(H2, erl);
|
||||
break;
|
||||
case Aec3Optimization::kAvx2:
|
||||
aec3::ErlComputer_AVX2(H2, erl);
|
||||
break;
|
||||
#endif
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
case Aec3Optimization::kNeon:
|
||||
|
||||
aec3::ErlComputer_NEON(H2, erl);
|
||||
break;
|
||||
#endif
|
||||
|
||||
@ -36,6 +36,10 @@ void ErlComputer_NEON(
|
||||
void ErlComputer_SSE2(
|
||||
const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
|
||||
rtc::ArrayView<float> erl);
|
||||
|
||||
void ErlComputer_AVX2(
|
||||
const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
|
||||
rtc::ArrayView<float> erl);
|
||||
#endif
|
||||
|
||||
} // namespace aec3
|
||||
|
||||
@ -0,0 +1,37 @@
|
||||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h"
|
||||
|
||||
#include <immintrin.h>
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
namespace aec3 {
|
||||
|
||||
// Computes and stores the echo return loss estimate of the filter, which is the
|
||||
// sum of the partition frequency responses.
|
||||
void ErlComputer_AVX2(
|
||||
const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
|
||||
rtc::ArrayView<float> erl) {
|
||||
std::fill(erl.begin(), erl.end(), 0.f);
|
||||
for (auto& H2_j : H2) {
|
||||
for (size_t k = 0; k < kFftLengthBy2; k += 8) {
|
||||
const __m256 H2_j_k = _mm256_loadu_ps(&H2_j[k]);
|
||||
__m256 erl_k = _mm256_loadu_ps(&erl[k]);
|
||||
erl_k = _mm256_add_ps(erl_k, H2_j_k);
|
||||
_mm256_storeu_ps(&erl[k], erl_k);
|
||||
}
|
||||
erl[kFftLengthBy2] += H2_j[kFftLengthBy2];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace aec3
|
||||
} // namespace webrtc
|
||||
@ -75,6 +75,31 @@ TEST(AdaptiveFirFilter, UpdateErlSse2Optimization) {
|
||||
}
|
||||
}
|
||||
|
||||
// Verifies that the optimized method for echo return loss computation is
|
||||
// bitexact to the reference counterpart.
|
||||
TEST(AdaptiveFirFilter, UpdateErlAvx2Optimization) {
|
||||
bool use_avx2 = (WebRtc_GetCPUInfo(kAVX2) != 0);
|
||||
if (use_avx2) {
|
||||
const size_t kNumPartitions = 12;
|
||||
std::vector<std::array<float, kFftLengthBy2Plus1>> H2(kNumPartitions);
|
||||
std::array<float, kFftLengthBy2Plus1> erl;
|
||||
std::array<float, kFftLengthBy2Plus1> erl_AVX2;
|
||||
|
||||
for (size_t j = 0; j < H2.size(); ++j) {
|
||||
for (size_t k = 0; k < H2[j].size(); ++k) {
|
||||
H2[j][k] = k + j / 3.f;
|
||||
}
|
||||
}
|
||||
|
||||
ErlComputer(H2, erl);
|
||||
ErlComputer_AVX2(H2, erl_AVX2);
|
||||
|
||||
for (size_t j = 0; j < erl.size(); ++j) {
|
||||
EXPECT_FLOAT_EQ(erl[j], erl_AVX2[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace aec3
|
||||
|
||||
@ -246,6 +246,81 @@ TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels,
|
||||
}
|
||||
}
|
||||
|
||||
// Verifies that the optimized methods for filter adaptation are bitexact to
|
||||
// their reference counterparts.
|
||||
TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels,
|
||||
FilterAdaptationAvx2Optimizations) {
|
||||
const size_t num_render_channels = GetParam();
|
||||
constexpr int kSampleRateHz = 48000;
|
||||
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
|
||||
|
||||
bool use_avx2 = (WebRtc_GetCPUInfo(kAVX2) != 0);
|
||||
if (use_avx2) {
|
||||
for (size_t num_partitions : {2, 5, 12, 30, 50}) {
|
||||
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
|
||||
RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz,
|
||||
num_render_channels));
|
||||
Random random_generator(42U);
|
||||
std::vector<std::vector<std::vector<float>>> x(
|
||||
kNumBands,
|
||||
std::vector<std::vector<float>>(num_render_channels,
|
||||
std::vector<float>(kBlockSize, 0.f)));
|
||||
FftData S_C;
|
||||
FftData S_Avx2;
|
||||
FftData G;
|
||||
Aec3Fft fft;
|
||||
std::vector<std::vector<FftData>> H_C(
|
||||
num_partitions, std::vector<FftData>(num_render_channels));
|
||||
std::vector<std::vector<FftData>> H_Avx2(
|
||||
num_partitions, std::vector<FftData>(num_render_channels));
|
||||
for (size_t p = 0; p < num_partitions; ++p) {
|
||||
for (size_t ch = 0; ch < num_render_channels; ++ch) {
|
||||
H_C[p][ch].Clear();
|
||||
H_Avx2[p][ch].Clear();
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t k = 0; k < 500; ++k) {
|
||||
for (size_t band = 0; band < x.size(); ++band) {
|
||||
for (size_t ch = 0; ch < x[band].size(); ++ch) {
|
||||
RandomizeSampleVector(&random_generator, x[band][ch]);
|
||||
}
|
||||
}
|
||||
render_delay_buffer->Insert(x);
|
||||
if (k == 0) {
|
||||
render_delay_buffer->Reset();
|
||||
}
|
||||
render_delay_buffer->PrepareCaptureProcessing();
|
||||
auto* const render_buffer = render_delay_buffer->GetRenderBuffer();
|
||||
|
||||
ApplyFilter_Avx2(*render_buffer, num_partitions, H_Avx2, &S_Avx2);
|
||||
ApplyFilter(*render_buffer, num_partitions, H_C, &S_C);
|
||||
for (size_t j = 0; j < S_C.re.size(); ++j) {
|
||||
EXPECT_FLOAT_EQ(S_C.re[j], S_Avx2.re[j]);
|
||||
EXPECT_FLOAT_EQ(S_C.im[j], S_Avx2.im[j]);
|
||||
}
|
||||
|
||||
std::for_each(G.re.begin(), G.re.end(),
|
||||
[&](float& a) { a = random_generator.Rand<float>(); });
|
||||
std::for_each(G.im.begin(), G.im.end(),
|
||||
[&](float& a) { a = random_generator.Rand<float>(); });
|
||||
|
||||
AdaptPartitions_Avx2(*render_buffer, G, num_partitions, &H_Avx2);
|
||||
AdaptPartitions(*render_buffer, G, num_partitions, &H_C);
|
||||
|
||||
for (size_t p = 0; p < num_partitions; ++p) {
|
||||
for (size_t ch = 0; ch < num_render_channels; ++ch) {
|
||||
for (size_t j = 0; j < H_C[p][ch].re.size(); ++j) {
|
||||
EXPECT_FLOAT_EQ(H_C[p][ch].re[j], H_Avx2[p][ch].re[j]);
|
||||
EXPECT_FLOAT_EQ(H_C[p][ch].im[j], H_Avx2[p][ch].im[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verifies that the optimized method for frequency response computation is
|
||||
// bitexact to the reference counterpart.
|
||||
TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels,
|
||||
@ -281,6 +356,41 @@ TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels,
|
||||
}
|
||||
}
|
||||
|
||||
// Verifies that the optimized method for frequency response computation is
|
||||
// bitexact to the reference counterpart.
|
||||
TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels,
|
||||
ComputeFrequencyResponseAvx2Optimization) {
|
||||
const size_t num_render_channels = GetParam();
|
||||
bool use_avx2 = (WebRtc_GetCPUInfo(kAVX2) != 0);
|
||||
if (use_avx2) {
|
||||
for (size_t num_partitions : {2, 5, 12, 30, 50}) {
|
||||
std::vector<std::vector<FftData>> H(
|
||||
num_partitions, std::vector<FftData>(num_render_channels));
|
||||
std::vector<std::array<float, kFftLengthBy2Plus1>> H2(num_partitions);
|
||||
std::vector<std::array<float, kFftLengthBy2Plus1>> H2_Avx2(
|
||||
num_partitions);
|
||||
|
||||
for (size_t p = 0; p < num_partitions; ++p) {
|
||||
for (size_t ch = 0; ch < num_render_channels; ++ch) {
|
||||
for (size_t k = 0; k < H[p][ch].re.size(); ++k) {
|
||||
H[p][ch].re[k] = k + p / 3.f + ch;
|
||||
H[p][ch].im[k] = p + k / 7.f - ch;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ComputeFrequencyResponse(num_partitions, H, &H2);
|
||||
ComputeFrequencyResponse_Avx2(num_partitions, H, &H2_Avx2);
|
||||
|
||||
for (size_t p = 0; p < num_partitions; ++p) {
|
||||
for (size_t k = 0; k < H2[p].size(); ++k) {
|
||||
EXPECT_FLOAT_EQ(H2[p][k], H2_Avx2[p][k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
|
||||
|
||||
@ -20,7 +20,9 @@ namespace webrtc {
|
||||
|
||||
Aec3Optimization DetectOptimization() {
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
if (WebRtc_GetCPUInfo(kSSE2) != 0) {
|
||||
if (WebRtc_GetCPUInfo(kAVX2) != 0) {
|
||||
return Aec3Optimization::kAvx2;
|
||||
} else if (WebRtc_GetCPUInfo(kSSE2) != 0) {
|
||||
return Aec3Optimization::kSse2;
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -23,7 +23,7 @@ namespace webrtc {
|
||||
#define ALIGN16_END __attribute__((aligned(16)))
|
||||
#endif
|
||||
|
||||
enum class Aec3Optimization { kNone, kSse2, kNeon };
|
||||
enum class Aec3Optimization { kNone, kSse2, kAvx2, kNeon };
|
||||
|
||||
constexpr int kNumBlocksPerSecond = 250;
|
||||
|
||||
|
||||
@ -40,6 +40,9 @@ struct FftData {
|
||||
im.fill(0.f);
|
||||
}
|
||||
|
||||
// Computes the power spectrum of the data.
|
||||
void SpectrumAVX2(rtc::ArrayView<float> power_spectrum) const;
|
||||
|
||||
// Computes the power spectrum of the data.
|
||||
void Spectrum(Aec3Optimization optimization,
|
||||
rtc::ArrayView<float> power_spectrum) const {
|
||||
@ -60,6 +63,9 @@ struct FftData {
|
||||
power_spectrum[kFftLengthBy2] = re[kFftLengthBy2] * re[kFftLengthBy2] +
|
||||
im[kFftLengthBy2] * im[kFftLengthBy2];
|
||||
} break;
|
||||
case Aec3Optimization::kAvx2:
|
||||
SpectrumAVX2(power_spectrum);
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
std::transform(re.begin(), re.end(), im.begin(), power_spectrum.begin(),
|
||||
|
||||
33
modules/audio_processing/aec3/fft_data_avx2.cc
Normal file
33
modules/audio_processing/aec3/fft_data_avx2.cc
Normal file
@ -0,0 +1,33 @@
|
||||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/aec3/fft_data.h"
|
||||
|
||||
#include <immintrin.h>
|
||||
|
||||
#include "api/array_view.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
// Computes the power spectrum of the data.
|
||||
void FftData::SpectrumAVX2(rtc::ArrayView<float> power_spectrum) const {
|
||||
RTC_DCHECK_EQ(kFftLengthBy2Plus1, power_spectrum.size());
|
||||
for (size_t k = 0; k < kFftLengthBy2; k += 8) {
|
||||
__m256 r = _mm256_loadu_ps(&re[k]);
|
||||
__m256 i = _mm256_loadu_ps(&im[k]);
|
||||
__m256 ii = _mm256_mul_ps(i, i);
|
||||
ii = _mm256_fmadd_ps(r, r, ii);
|
||||
_mm256_storeu_ps(&power_spectrum[k], ii);
|
||||
}
|
||||
power_spectrum[kFftLengthBy2] = re[kFftLengthBy2] * re[kFftLengthBy2] +
|
||||
im[kFftLengthBy2] * im[kFftLengthBy2];
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
||||
@ -19,7 +19,7 @@ namespace webrtc {
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
// Verifies that the optimized methods are bitexact to their reference
|
||||
// counterparts.
|
||||
TEST(FftData, TestOptimizations) {
|
||||
TEST(FftData, TestSse2Optimizations) {
|
||||
if (WebRtc_GetCPUInfo(kSSE2) != 0) {
|
||||
FftData x;
|
||||
|
||||
@ -39,6 +39,29 @@ TEST(FftData, TestOptimizations) {
|
||||
EXPECT_EQ(spectrum, spectrum_sse2);
|
||||
}
|
||||
}
|
||||
|
||||
// Verifies that the optimized methods are bitexact to their reference
|
||||
// counterparts.
|
||||
TEST(FftData, TestAvx2Optimizations) {
|
||||
if (WebRtc_GetCPUInfo(kAVX2) != 0) {
|
||||
FftData x;
|
||||
|
||||
for (size_t k = 0; k < x.re.size(); ++k) {
|
||||
x.re[k] = k + 1;
|
||||
}
|
||||
|
||||
x.im[0] = x.im[x.im.size() - 1] = 0.f;
|
||||
for (size_t k = 1; k < x.im.size() - 1; ++k) {
|
||||
x.im[k] = 2.f * (k + 1);
|
||||
}
|
||||
|
||||
std::array<float, kFftLengthBy2Plus1> spectrum;
|
||||
std::array<float, kFftLengthBy2Plus1> spectrum_avx2;
|
||||
x.Spectrum(Aec3Optimization::kNone, spectrum);
|
||||
x.Spectrum(Aec3Optimization::kAvx2, spectrum_avx2);
|
||||
EXPECT_EQ(spectrum, spectrum_avx2);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
|
||||
|
||||
@ -364,6 +364,11 @@ void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
|
||||
smoothing_, render_buffer.buffer, y,
|
||||
filters_[n], &filters_updated, &error_sum);
|
||||
break;
|
||||
case Aec3Optimization::kAvx2:
|
||||
aec3::MatchedFilterCore_AVX2(x_start_index, x2_sum_threshold,
|
||||
smoothing_, render_buffer.buffer, y,
|
||||
filters_[n], &filters_updated, &error_sum);
|
||||
break;
|
||||
#endif
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
case Aec3Optimization::kNeon:
|
||||
|
||||
@ -53,6 +53,16 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
|
||||
bool* filters_updated,
|
||||
float* error_sum);
|
||||
|
||||
// Filter core for the matched filter that is optimized for AVX2.
|
||||
void MatchedFilterCore_AVX2(size_t x_start_index,
|
||||
float x2_sum_threshold,
|
||||
float smoothing,
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<const float> y,
|
||||
rtc::ArrayView<float> h,
|
||||
bool* filters_updated,
|
||||
float* error_sum);
|
||||
|
||||
#endif
|
||||
|
||||
// Filter core for the matched filter.
|
||||
|
||||
132
modules/audio_processing/aec3/matched_filter_avx2.cc
Normal file
132
modules/audio_processing/aec3/matched_filter_avx2.cc
Normal file
@ -0,0 +1,132 @@
|
||||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/aec3/matched_filter.h"
|
||||
|
||||
#include <immintrin.h>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace aec3 {
|
||||
|
||||
void MatchedFilterCore_AVX2(size_t x_start_index,
|
||||
float x2_sum_threshold,
|
||||
float smoothing,
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<const float> y,
|
||||
rtc::ArrayView<float> h,
|
||||
bool* filters_updated,
|
||||
float* error_sum) {
|
||||
const int h_size = static_cast<int>(h.size());
|
||||
const int x_size = static_cast<int>(x.size());
|
||||
RTC_DCHECK_EQ(0, h_size % 8);
|
||||
|
||||
// Process for all samples in the sub-block.
|
||||
for (size_t i = 0; i < y.size(); ++i) {
|
||||
// Apply the matched filter as filter * x, and compute x * x.
|
||||
|
||||
RTC_DCHECK_GT(x_size, x_start_index);
|
||||
const float* x_p = &x[x_start_index];
|
||||
const float* h_p = &h[0];
|
||||
|
||||
// Initialize values for the accumulation.
|
||||
__m256 s_256 = _mm256_set1_ps(0);
|
||||
__m256 x2_sum_256 = _mm256_set1_ps(0);
|
||||
float x2_sum = 0.f;
|
||||
float s = 0;
|
||||
|
||||
// Compute loop chunk sizes until, and after, the wraparound of the circular
|
||||
// buffer for x.
|
||||
const int chunk1 =
|
||||
std::min(h_size, static_cast<int>(x_size - x_start_index));
|
||||
|
||||
// Perform the loop in two chunks.
|
||||
const int chunk2 = h_size - chunk1;
|
||||
for (int limit : {chunk1, chunk2}) {
|
||||
// Perform 256 bit vector operations.
|
||||
const int limit_by_8 = limit >> 3;
|
||||
for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) {
|
||||
// Load the data into 256 bit vectors.
|
||||
__m256 x_k = _mm256_loadu_ps(x_p);
|
||||
__m256 h_k = _mm256_loadu_ps(h_p);
|
||||
// Compute and accumulate x * x and h * x.
|
||||
x2_sum_256 = _mm256_fmadd_ps(x_k, x_k, x2_sum_256);
|
||||
s_256 = _mm256_fmadd_ps(h_k, x_k, s_256);
|
||||
}
|
||||
|
||||
// Perform non-vector operations for any remaining items.
|
||||
for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) {
|
||||
const float x_k = *x_p;
|
||||
x2_sum += x_k * x_k;
|
||||
s += *h_p * x_k;
|
||||
}
|
||||
|
||||
x_p = &x[0];
|
||||
}
|
||||
|
||||
// Sum components together.
|
||||
__m128 x2_sum_128 = _mm_add_ps(_mm256_extractf128_ps(x2_sum_256, 0),
|
||||
_mm256_extractf128_ps(x2_sum_256, 1));
|
||||
__m128 s_128 = _mm_add_ps(_mm256_extractf128_ps(s_256, 0),
|
||||
_mm256_extractf128_ps(s_256, 1));
|
||||
// Combine the accumulated vector and scalar values.
|
||||
float* v = reinterpret_cast<float*>(&x2_sum_128);
|
||||
x2_sum += v[0] + v[1] + v[2] + v[3];
|
||||
v = reinterpret_cast<float*>(&s_128);
|
||||
s += v[0] + v[1] + v[2] + v[3];
|
||||
|
||||
// Compute the matched filter error.
|
||||
float e = y[i] - s;
|
||||
const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
|
||||
(*error_sum) += e * e;
|
||||
|
||||
// Update the matched filter estimate in an NLMS manner.
|
||||
if (x2_sum > x2_sum_threshold && !saturation) {
|
||||
RTC_DCHECK_LT(0.f, x2_sum);
|
||||
const float alpha = smoothing * e / x2_sum;
|
||||
const __m256 alpha_256 = _mm256_set1_ps(alpha);
|
||||
|
||||
// filter = filter + smoothing * (y - filter * x) * x / x * x.
|
||||
float* h_p = &h[0];
|
||||
x_p = &x[x_start_index];
|
||||
|
||||
// Perform the loop in two chunks.
|
||||
for (int limit : {chunk1, chunk2}) {
|
||||
// Perform 256 bit vector operations.
|
||||
const int limit_by_8 = limit >> 3;
|
||||
for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) {
|
||||
// Load the data into 256 bit vectors.
|
||||
__m256 h_k = _mm256_loadu_ps(h_p);
|
||||
__m256 x_k = _mm256_loadu_ps(x_p);
|
||||
// Compute h = h + alpha * x.
|
||||
h_k = _mm256_fmadd_ps(x_k, alpha_256, h_k);
|
||||
|
||||
// Store the result.
|
||||
_mm256_storeu_ps(h_p, h_k);
|
||||
}
|
||||
|
||||
// Perform non-vector operations for any remaining items.
|
||||
for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) {
|
||||
*h_p += alpha * *x_p;
|
||||
}
|
||||
|
||||
x_p = &x[0];
|
||||
}
|
||||
|
||||
*filters_updated = true;
|
||||
}
|
||||
|
||||
x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace aec3
|
||||
} // namespace webrtc
|
||||
@ -133,6 +133,47 @@ TEST(MatchedFilter, TestSse2Optimizations) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(MatchedFilter, TestAvx2Optimizations) {
|
||||
bool use_avx2 = (WebRtc_GetCPUInfo(kAVX2) != 0);
|
||||
if (use_avx2) {
|
||||
Random random_generator(42U);
|
||||
constexpr float kSmoothing = 0.7f;
|
||||
for (auto down_sampling_factor : kDownSamplingFactors) {
|
||||
const size_t sub_block_size = kBlockSize / down_sampling_factor;
|
||||
std::vector<float> x(2000);
|
||||
RandomizeSampleVector(&random_generator, x);
|
||||
std::vector<float> y(sub_block_size);
|
||||
std::vector<float> h_AVX2(512);
|
||||
std::vector<float> h(512);
|
||||
int x_index = 0;
|
||||
for (int k = 0; k < 1000; ++k) {
|
||||
RandomizeSampleVector(&random_generator, y);
|
||||
|
||||
bool filters_updated = false;
|
||||
float error_sum = 0.f;
|
||||
bool filters_updated_AVX2 = false;
|
||||
float error_sum_AVX2 = 0.f;
|
||||
|
||||
MatchedFilterCore_AVX2(x_index, h.size() * 150.f * 150.f, kSmoothing, x,
|
||||
y, h_AVX2, &filters_updated_AVX2,
|
||||
&error_sum_AVX2);
|
||||
|
||||
MatchedFilterCore(x_index, h.size() * 150.f * 150.f, kSmoothing, x, y,
|
||||
h, &filters_updated, &error_sum);
|
||||
|
||||
EXPECT_EQ(filters_updated, filters_updated_AVX2);
|
||||
EXPECT_NEAR(error_sum, error_sum_AVX2, error_sum / 100000.f);
|
||||
|
||||
for (size_t j = 0; j < h.size(); ++j) {
|
||||
EXPECT_NEAR(h[j], h_AVX2[j], 0.00001f);
|
||||
}
|
||||
|
||||
x_index = (x_index + sub_block_size) % x.size();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// Verifies that the matched filter produces proper lag estimates for
|
||||
|
||||
@ -40,6 +40,7 @@ class VectorMath {
|
||||
: optimization_(optimization) {}
|
||||
|
||||
// Elementwise square root.
|
||||
void SqrtAVX2(rtc::ArrayView<float> x);
|
||||
void Sqrt(rtc::ArrayView<float> x) {
|
||||
switch (optimization_) {
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
@ -58,6 +59,9 @@ class VectorMath {
|
||||
x[j] = sqrtf(x[j]);
|
||||
}
|
||||
} break;
|
||||
case Aec3Optimization::kAvx2:
|
||||
SqrtAVX2(x);
|
||||
break;
|
||||
#endif
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
case Aec3Optimization::kNeon: {
|
||||
@ -110,6 +114,9 @@ class VectorMath {
|
||||
}
|
||||
|
||||
// Elementwise vector multiplication z = x * y.
|
||||
void MultiplyAVX2(rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<const float> y,
|
||||
rtc::ArrayView<float> z);
|
||||
void Multiply(rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<const float> y,
|
||||
rtc::ArrayView<float> z) {
|
||||
@ -133,6 +140,9 @@ class VectorMath {
|
||||
z[j] = x[j] * y[j];
|
||||
}
|
||||
} break;
|
||||
case Aec3Optimization::kAvx2:
|
||||
MultiplyAVX2(x, y, z);
|
||||
break;
|
||||
#endif
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
case Aec3Optimization::kNeon: {
|
||||
@ -159,6 +169,7 @@ class VectorMath {
|
||||
}
|
||||
|
||||
// Elementwise vector accumulation z += x.
|
||||
void AccumulateAVX2(rtc::ArrayView<const float> x, rtc::ArrayView<float> z);
|
||||
void Accumulate(rtc::ArrayView<const float> x, rtc::ArrayView<float> z) {
|
||||
RTC_DCHECK_EQ(z.size(), x.size());
|
||||
switch (optimization_) {
|
||||
@ -179,6 +190,9 @@ class VectorMath {
|
||||
z[j] += x[j];
|
||||
}
|
||||
} break;
|
||||
case Aec3Optimization::kAvx2:
|
||||
AccumulateAVX2(x, z);
|
||||
break;
|
||||
#endif
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
case Aec3Optimization::kNeon: {
|
||||
|
||||
82
modules/audio_processing/aec3/vector_math_avx2.cc
Normal file
82
modules/audio_processing/aec3/vector_math_avx2.cc
Normal file
@ -0,0 +1,82 @@
|
||||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/aec3/vector_math.h"
|
||||
|
||||
#include <immintrin.h>
|
||||
#include <math.h>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace aec3 {
|
||||
|
||||
// Elementwise square root.
|
||||
void VectorMath::SqrtAVX2(rtc::ArrayView<float> x) {
|
||||
const int x_size = static_cast<int>(x.size());
|
||||
const int vector_limit = x_size >> 3;
|
||||
|
||||
int j = 0;
|
||||
for (; j < vector_limit * 8; j += 8) {
|
||||
__m256 g = _mm256_loadu_ps(&x[j]);
|
||||
g = _mm256_sqrt_ps(g);
|
||||
_mm256_storeu_ps(&x[j], g);
|
||||
}
|
||||
|
||||
for (; j < x_size; ++j) {
|
||||
x[j] = sqrtf(x[j]);
|
||||
}
|
||||
}
|
||||
|
||||
// Elementwise vector multiplication z = x * y.
|
||||
void VectorMath::MultiplyAVX2(rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<const float> y,
|
||||
rtc::ArrayView<float> z) {
|
||||
RTC_DCHECK_EQ(z.size(), x.size());
|
||||
RTC_DCHECK_EQ(z.size(), y.size());
|
||||
const int x_size = static_cast<int>(x.size());
|
||||
const int vector_limit = x_size >> 3;
|
||||
|
||||
int j = 0;
|
||||
for (; j < vector_limit * 8; j += 8) {
|
||||
const __m256 x_j = _mm256_loadu_ps(&x[j]);
|
||||
const __m256 y_j = _mm256_loadu_ps(&y[j]);
|
||||
const __m256 z_j = _mm256_mul_ps(x_j, y_j);
|
||||
_mm256_storeu_ps(&z[j], z_j);
|
||||
}
|
||||
|
||||
for (; j < x_size; ++j) {
|
||||
z[j] = x[j] * y[j];
|
||||
}
|
||||
}
|
||||
|
||||
// Elementwise vector accumulation z += x.
|
||||
void VectorMath::AccumulateAVX2(rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<float> z) {
|
||||
RTC_DCHECK_EQ(z.size(), x.size());
|
||||
const int x_size = static_cast<int>(x.size());
|
||||
const int vector_limit = x_size >> 3;
|
||||
|
||||
int j = 0;
|
||||
for (; j < vector_limit * 8; j += 8) {
|
||||
const __m256 x_j = _mm256_loadu_ps(&x[j]);
|
||||
__m256 z_j = _mm256_loadu_ps(&z[j]);
|
||||
z_j = _mm256_add_ps(x_j, z_j);
|
||||
_mm256_storeu_ps(&z[j], z_j);
|
||||
}
|
||||
|
||||
for (; j < x_size; ++j) {
|
||||
z[j] += x[j];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace aec3
|
||||
} // namespace webrtc
|
||||
@ -79,7 +79,7 @@ TEST(VectorMath, Accumulate) {
|
||||
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
|
||||
TEST(VectorMath, Sqrt) {
|
||||
TEST(VectorMath, Sse2Sqrt) {
|
||||
if (WebRtc_GetCPUInfo(kSSE2) != 0) {
|
||||
std::array<float, kFftLengthBy2Plus1> x;
|
||||
std::array<float, kFftLengthBy2Plus1> z;
|
||||
@ -101,7 +101,29 @@ TEST(VectorMath, Sqrt) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(VectorMath, Multiply) {
|
||||
TEST(VectorMath, Avx2Sqrt) {
|
||||
if (WebRtc_GetCPUInfo(kAVX2) != 0) {
|
||||
std::array<float, kFftLengthBy2Plus1> x;
|
||||
std::array<float, kFftLengthBy2Plus1> z;
|
||||
std::array<float, kFftLengthBy2Plus1> z_avx2;
|
||||
|
||||
for (size_t k = 0; k < x.size(); ++k) {
|
||||
x[k] = (2.f / 3.f) * k;
|
||||
}
|
||||
|
||||
std::copy(x.begin(), x.end(), z.begin());
|
||||
aec3::VectorMath(Aec3Optimization::kNone).Sqrt(z);
|
||||
std::copy(x.begin(), x.end(), z_avx2.begin());
|
||||
aec3::VectorMath(Aec3Optimization::kAvx2).Sqrt(z_avx2);
|
||||
EXPECT_EQ(z, z_avx2);
|
||||
for (size_t k = 0; k < z.size(); ++k) {
|
||||
EXPECT_FLOAT_EQ(z[k], z_avx2[k]);
|
||||
EXPECT_FLOAT_EQ(sqrtf(x[k]), z_avx2[k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(VectorMath, Sse2Multiply) {
|
||||
if (WebRtc_GetCPUInfo(kSSE2) != 0) {
|
||||
std::array<float, kFftLengthBy2Plus1> x;
|
||||
std::array<float, kFftLengthBy2Plus1> y;
|
||||
@ -122,7 +144,28 @@ TEST(VectorMath, Multiply) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(VectorMath, Accumulate) {
|
||||
TEST(VectorMath, Avx2Multiply) {
|
||||
if (WebRtc_GetCPUInfo(kAVX2) != 0) {
|
||||
std::array<float, kFftLengthBy2Plus1> x;
|
||||
std::array<float, kFftLengthBy2Plus1> y;
|
||||
std::array<float, kFftLengthBy2Plus1> z;
|
||||
std::array<float, kFftLengthBy2Plus1> z_avx2;
|
||||
|
||||
for (size_t k = 0; k < x.size(); ++k) {
|
||||
x[k] = k;
|
||||
y[k] = (2.f / 3.f) * k;
|
||||
}
|
||||
|
||||
aec3::VectorMath(Aec3Optimization::kNone).Multiply(x, y, z);
|
||||
aec3::VectorMath(Aec3Optimization::kAvx2).Multiply(x, y, z_avx2);
|
||||
for (size_t k = 0; k < z.size(); ++k) {
|
||||
EXPECT_FLOAT_EQ(z[k], z_avx2[k]);
|
||||
EXPECT_FLOAT_EQ(x[k] * y[k], z_avx2[k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(VectorMath, Sse2Accumulate) {
|
||||
if (WebRtc_GetCPUInfo(kSSE2) != 0) {
|
||||
std::array<float, kFftLengthBy2Plus1> x;
|
||||
std::array<float, kFftLengthBy2Plus1> z;
|
||||
@ -141,6 +184,26 @@ TEST(VectorMath, Accumulate) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(VectorMath, Avx2Accumulate) {
|
||||
if (WebRtc_GetCPUInfo(kAVX2) != 0) {
|
||||
std::array<float, kFftLengthBy2Plus1> x;
|
||||
std::array<float, kFftLengthBy2Plus1> z;
|
||||
std::array<float, kFftLengthBy2Plus1> z_avx2;
|
||||
|
||||
for (size_t k = 0; k < x.size(); ++k) {
|
||||
x[k] = k;
|
||||
z[k] = z_avx2[k] = 2.f * k;
|
||||
}
|
||||
|
||||
aec3::VectorMath(Aec3Optimization::kNone).Accumulate(x, z);
|
||||
aec3::VectorMath(Aec3Optimization::kAvx2).Accumulate(x, z_avx2);
|
||||
for (size_t k = 0; k < z.size(); ++k) {
|
||||
EXPECT_FLOAT_EQ(z[k], z_avx2[k]);
|
||||
EXPECT_FLOAT_EQ(x[k] + 2.f * x[k], z_avx2[k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user