From e537e9ca13e448bcb690c0bb173e8a7a1117bcce Mon Sep 17 00:00:00 2001 From: Zhaoliang Ma Date: Mon, 31 Aug 2020 10:20:47 +0800 Subject: [PATCH] aec3: Support AVX2/FMA intrinsics in AEC3 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug: webrtc:11663 Change-Id: Ib75eb616ef0cb62698b0d96af7ebe42e93825222 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/179006 Reviewed-by: Mirko Bonadei Reviewed-by: Sam Zackrisson Reviewed-by: Per Åhgren Commit-Queue: Mirko Bonadei Cr-Commit-Position: refs/heads/master@{#32023} --- modules/audio_processing/aec3/BUILD.gn | 149 ++++++++++++-- .../aec3/adaptive_fir_filter.cc | 10 + .../aec3/adaptive_fir_filter.h | 15 ++ .../aec3/adaptive_fir_filter_avx2.cc | 187 ++++++++++++++++++ .../aec3/adaptive_fir_filter_erl.cc | 4 +- .../aec3/adaptive_fir_filter_erl.h | 4 + .../aec3/adaptive_fir_filter_erl_avx2.cc | 37 ++++ .../aec3/adaptive_fir_filter_erl_unittest.cc | 25 +++ .../aec3/adaptive_fir_filter_unittest.cc | 110 +++++++++++ modules/audio_processing/aec3/aec3_common.cc | 4 +- modules/audio_processing/aec3/aec3_common.h | 2 +- modules/audio_processing/aec3/fft_data.h | 6 + .../audio_processing/aec3/fft_data_avx2.cc | 33 ++++ .../aec3/fft_data_unittest.cc | 25 ++- .../audio_processing/aec3/matched_filter.cc | 5 + .../audio_processing/aec3/matched_filter.h | 10 + .../aec3/matched_filter_avx2.cc | 132 +++++++++++++ .../aec3/matched_filter_unittest.cc | 41 ++++ modules/audio_processing/aec3/vector_math.h | 14 ++ .../audio_processing/aec3/vector_math_avx2.cc | 82 ++++++++ .../aec3/vector_math_unittest.cc | 69 ++++++- 21 files changed, 945 insertions(+), 19 deletions(-) create mode 100644 modules/audio_processing/aec3/adaptive_fir_filter_avx2.cc create mode 100644 modules/audio_processing/aec3/adaptive_fir_filter_erl_avx2.cc create mode 100644 modules/audio_processing/aec3/fft_data_avx2.cc create mode 100644 modules/audio_processing/aec3/matched_filter_avx2.cc create mode 100644 
modules/audio_processing/aec3/vector_math_avx2.cc diff --git a/modules/audio_processing/aec3/BUILD.gn b/modules/audio_processing/aec3/BUILD.gn index 507f2bc8bd..c312b0ebd2 100644 --- a/modules/audio_processing/aec3/BUILD.gn +++ b/modules/audio_processing/aec3/BUILD.gn @@ -13,13 +13,9 @@ rtc_library("aec3") { configs += [ "..:apm_debug_dump" ] sources = [ "adaptive_fir_filter.cc", - "adaptive_fir_filter.h", "adaptive_fir_filter_erl.cc", - "adaptive_fir_filter_erl.h", "aec3_common.cc", - "aec3_common.h", "aec3_fft.cc", - "aec3_fft.h", "aec_state.cc", "aec_state.h", "alignment_mixer.cc", @@ -27,7 +23,6 @@ rtc_library("aec3") { "api_call_jitter_metrics.cc", "api_call_jitter_metrics.h", "block_buffer.cc", - "block_buffer.h", "block_delay_buffer.cc", "block_delay_buffer.h", "block_framer.cc", @@ -66,8 +61,6 @@ rtc_library("aec3") { "erle_estimator.cc", "erle_estimator.h", "fft_buffer.cc", - "fft_buffer.h", - "fft_data.h", "filter_analyzer.cc", "filter_analyzer.h", "frame_blocker.cc", @@ -75,7 +68,6 @@ rtc_library("aec3") { "fullband_erle_estimator.cc", "fullband_erle_estimator.h", "matched_filter.cc", - "matched_filter.h", "matched_filter_lag_aggregator.cc", "matched_filter_lag_aggregator.h", "moving_average.cc", @@ -84,7 +76,6 @@ rtc_library("aec3") { "refined_filter_update_gain.cc", "refined_filter_update_gain.h", "render_buffer.cc", - "render_buffer.h", "render_delay_buffer.cc", "render_delay_buffer.h", "render_delay_controller.cc", @@ -106,7 +97,6 @@ rtc_library("aec3") { "signal_dependent_erle_estimator.cc", "signal_dependent_erle_estimator.h", "spectrum_buffer.cc", - "spectrum_buffer.h", "stationarity_estimator.cc", "stationarity_estimator.h", "subband_erle_estimator.cc", @@ -123,7 +113,6 @@ rtc_library("aec3") { "suppression_filter.h", "suppression_gain.cc", "suppression_gain.h", - "vector_math.h", ] defines = [] @@ -133,6 +122,14 @@ rtc_library("aec3") { } deps = [ + ":adaptive_fir_filter", + ":adaptive_fir_filter_erl", + ":aec3_common", + ":aec3_fft", + 
":fft_data", + ":matched_filter", + ":render_buffer", + ":vector_math", "..:apm_logging", "..:audio_buffer", "..:high_pass_filter", @@ -140,7 +137,6 @@ rtc_library("aec3") { "../../../api/audio:aec3_config", "../../../api/audio:echo_control", "../../../common_audio:common_audio_c", - "../../../common_audio/third_party/ooura:fft_size_128", "../../../rtc_base:checks", "../../../rtc_base:rtc_base_approved", "../../../rtc_base:safe_minmax", @@ -152,6 +148,127 @@ rtc_library("aec3") { "../utility:cascaded_biquad_filter", ] absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + + if (current_cpu == "x86" || current_cpu == "x64") { + deps += [ ":aec3_avx2" ] + } +} + +rtc_source_set("aec3_common") { + sources = [ "aec3_common.h" ] +} + +rtc_source_set("aec3_fft") { + sources = [ "aec3_fft.h" ] + deps = [ + ":aec3_common", + ":fft_data", + "../../../api:array_view", + "../../../common_audio/third_party/ooura:fft_size_128", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../../../rtc_base/system:arch", + ] +} + +rtc_source_set("render_buffer") { + sources = [ + "block_buffer.h", + "fft_buffer.h", + "render_buffer.h", + "spectrum_buffer.h", + ] + deps = [ + ":aec3_common", + ":fft_data", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../../../rtc_base/system:arch", + ] +} + +rtc_source_set("adaptive_fir_filter") { + sources = [ "adaptive_fir_filter.h" ] + deps = [ + ":aec3_common", + ":aec3_fft", + ":fft_data", + ":render_buffer", + "..:apm_logging", + "../../../api:array_view", + "../../../rtc_base/system:arch", + ] +} + +rtc_source_set("adaptive_fir_filter_erl") { + sources = [ "adaptive_fir_filter_erl.h" ] + deps = [ + ":aec3_common", + "../../../api:array_view", + "../../../rtc_base/system:arch", + ] +} + +rtc_source_set("matched_filter") { + sources = [ "matched_filter.h" ] + deps = [ + ":aec3_common", + "../../../api:array_view", + "../../../rtc_base:rtc_base_approved", + 
"../../../rtc_base/system:arch", + ] +} + +rtc_source_set("vector_math") { + sources = [ "vector_math.h" ] + deps = [ + ":aec3_common", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base/system:arch", + ] +} + +rtc_source_set("fft_data") { + sources = [ "fft_data.h" ] + deps = [ + ":aec3_common", + "../../../api:array_view", + "../../../rtc_base/system:arch", + ] +} + +if (current_cpu == "x86" || current_cpu == "x64") { + rtc_library("aec3_avx2") { + configs += [ "..:apm_debug_dump" ] + sources = [ + "adaptive_fir_filter_avx2.cc", + "adaptive_fir_filter_erl_avx2.cc", + "fft_data_avx2.cc", + "matched_filter_avx2.cc", + "vector_math_avx2.cc", + ] + + if (is_win) { + cflags = [ "/arch:AVX2" ] + } else { + cflags = [ + "-mavx2", + "-mfma", + ] + } + + deps = [ + ":adaptive_fir_filter", + ":adaptive_fir_filter_erl", + ":fft_data", + ":matched_filter", + ":vector_math", + "../../../api:array_view", + "../../../rtc_base:checks", + ] + } } if (rtc_include_tests) { @@ -171,7 +288,15 @@ if (rtc_include_tests) { ] deps = [ + ":adaptive_fir_filter", + ":adaptive_fir_filter_erl", ":aec3", + ":aec3_common", + ":aec3_fft", + ":fft_data", + ":matched_filter", + ":render_buffer", + ":vector_math", "..:apm_logging", "..:audio_buffer", "..:audio_processing", diff --git a/modules/audio_processing/aec3/adaptive_fir_filter.cc b/modules/audio_processing/aec3/adaptive_fir_filter.cc index 6a0f531663..bf3a7809f4 100644 --- a/modules/audio_processing/aec3/adaptive_fir_filter.cc +++ b/modules/audio_processing/aec3/adaptive_fir_filter.cc @@ -556,6 +556,9 @@ void AdaptiveFirFilter::Filter(const RenderBuffer& render_buffer, case Aec3Optimization::kSse2: aec3::ApplyFilter_Sse2(render_buffer, current_size_partitions_, H_, S); break; + case Aec3Optimization::kAvx2: + aec3::ApplyFilter_Avx2(render_buffer, current_size_partitions_, H_, S); + break; #endif #if defined(WEBRTC_HAS_NEON) case Aec3Optimization::kNeon: @@ -597,6 +600,9 @@ void 
AdaptiveFirFilter::ComputeFrequencyResponse( case Aec3Optimization::kSse2: aec3::ComputeFrequencyResponse_Sse2(current_size_partitions_, H_, H2); break; + case Aec3Optimization::kAvx2: + aec3::ComputeFrequencyResponse_Avx2(current_size_partitions_, H_, H2); + break; #endif #if defined(WEBRTC_HAS_NEON) case Aec3Optimization::kNeon: @@ -620,6 +626,10 @@ void AdaptiveFirFilter::AdaptAndUpdateSize(const RenderBuffer& render_buffer, aec3::AdaptPartitions_Sse2(render_buffer, G, current_size_partitions_, &H_); break; + case Aec3Optimization::kAvx2: + aec3::AdaptPartitions_Avx2(render_buffer, G, current_size_partitions_, + &H_); + break; #endif #if defined(WEBRTC_HAS_NEON) case Aec3Optimization::kNeon: diff --git a/modules/audio_processing/aec3/adaptive_fir_filter.h b/modules/audio_processing/aec3/adaptive_fir_filter.h index 2f6485340f..7597709460 100644 --- a/modules/audio_processing/aec3/adaptive_fir_filter.h +++ b/modules/audio_processing/aec3/adaptive_fir_filter.h @@ -42,6 +42,11 @@ void ComputeFrequencyResponse_Sse2( size_t num_partitions, const std::vector>& H, std::vector>* H2); + +void ComputeFrequencyResponse_Avx2( + size_t num_partitions, + const std::vector>& H, + std::vector>* H2); #endif // Adapts the filter partitions. @@ -60,6 +65,11 @@ void AdaptPartitions_Sse2(const RenderBuffer& render_buffer, const FftData& G, size_t num_partitions, std::vector>* H); + +void AdaptPartitions_Avx2(const RenderBuffer& render_buffer, + const FftData& G, + size_t num_partitions, + std::vector>* H); #endif // Produces the filter output. 
@@ -78,6 +88,11 @@ void ApplyFilter_Sse2(const RenderBuffer& render_buffer, size_t num_partitions, const std::vector>& H, FftData* S); + +void ApplyFilter_Avx2(const RenderBuffer& render_buffer, + size_t num_partitions, + const std::vector>& H, + FftData* S); #endif } // namespace aec3 diff --git a/modules/audio_processing/aec3/adaptive_fir_filter_avx2.cc b/modules/audio_processing/aec3/adaptive_fir_filter_avx2.cc new file mode 100644 index 0000000000..245b45ac31 --- /dev/null +++ b/modules/audio_processing/aec3/adaptive_fir_filter_avx2.cc @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/adaptive_fir_filter.h" + +#include + +#include "rtc_base/checks.h" + +namespace webrtc { + +namespace aec3 { + +// Computes and stores the frequency response of the filter. 
+void ComputeFrequencyResponse_Avx2( + size_t num_partitions, + const std::vector>& H, + std::vector>* H2) { + for (auto& H2_ch : *H2) { + H2_ch.fill(0.f); + } + + const size_t num_render_channels = H[0].size(); + RTC_DCHECK_EQ(H.size(), H2->capacity()); + for (size_t p = 0; p < num_partitions; ++p) { + RTC_DCHECK_EQ(kFftLengthBy2Plus1, (*H2)[p].size()); + for (size_t ch = 0; ch < num_render_channels; ++ch) { + for (size_t j = 0; j < kFftLengthBy2; j += 8) { + __m256 re = _mm256_loadu_ps(&H[p][ch].re[j]); + __m256 re2 = _mm256_mul_ps(re, re); + __m256 im = _mm256_loadu_ps(&H[p][ch].im[j]); + re2 = _mm256_fmadd_ps(im, im, re2); + __m256 H2_k_j = _mm256_loadu_ps(&(*H2)[p][j]); + H2_k_j = _mm256_max_ps(H2_k_j, re2); + _mm256_storeu_ps(&(*H2)[p][j], H2_k_j); + } + float H2_new = H[p][ch].re[kFftLengthBy2] * H[p][ch].re[kFftLengthBy2] + + H[p][ch].im[kFftLengthBy2] * H[p][ch].im[kFftLengthBy2]; + (*H2)[p][kFftLengthBy2] = std::max((*H2)[p][kFftLengthBy2], H2_new); + } + } +} + +// Adapts the filter partitions. 
+void AdaptPartitions_Avx2(const RenderBuffer& render_buffer, + const FftData& G, + size_t num_partitions, + std::vector>* H) { + rtc::ArrayView> render_buffer_data = + render_buffer.GetFftBuffer(); + const size_t num_render_channels = render_buffer_data[0].size(); + const size_t lim1 = std::min( + render_buffer_data.size() - render_buffer.Position(), num_partitions); + const size_t lim2 = num_partitions; + constexpr size_t kNumEightBinBands = kFftLengthBy2 / 8; + + size_t X_partition = render_buffer.Position(); + size_t limit = lim1; + size_t p = 0; + do { + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + FftData& H_p_ch = (*H)[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + + for (size_t k = 0, n = 0; n < kNumEightBinBands; ++n, k += 8) { + const __m256 G_re = _mm256_loadu_ps(&G.re[k]); + const __m256 G_im = _mm256_loadu_ps(&G.im[k]); + const __m256 X_re = _mm256_loadu_ps(&X.re[k]); + const __m256 X_im = _mm256_loadu_ps(&X.im[k]); + const __m256 H_re = _mm256_loadu_ps(&H_p_ch.re[k]); + const __m256 H_im = _mm256_loadu_ps(&H_p_ch.im[k]); + const __m256 a = _mm256_mul_ps(X_re, G_re); + const __m256 b = _mm256_mul_ps(X_im, G_im); + const __m256 c = _mm256_mul_ps(X_re, G_im); + const __m256 d = _mm256_mul_ps(X_im, G_re); + const __m256 e = _mm256_add_ps(a, b); + const __m256 f = _mm256_sub_ps(c, d); + const __m256 g = _mm256_add_ps(H_re, e); + const __m256 h = _mm256_add_ps(H_im, f); + _mm256_storeu_ps(&H_p_ch.re[k], g); + _mm256_storeu_ps(&H_p_ch.im[k], h); + } + } + } + X_partition = 0; + limit = lim2; + } while (p < lim2); + + X_partition = render_buffer.Position(); + limit = lim1; + p = 0; + do { + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + FftData& H_p_ch = (*H)[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + + H_p_ch.re[kFftLengthBy2] += X.re[kFftLengthBy2] * G.re[kFftLengthBy2] + + X.im[kFftLengthBy2] * 
G.im[kFftLengthBy2]; + H_p_ch.im[kFftLengthBy2] += X.re[kFftLengthBy2] * G.im[kFftLengthBy2] - + X.im[kFftLengthBy2] * G.re[kFftLengthBy2]; + } + } + + X_partition = 0; + limit = lim2; + } while (p < lim2); +} + +// Produces the filter output (AVX2 variant). +void ApplyFilter_Avx2(const RenderBuffer& render_buffer, + size_t num_partitions, + const std::vector>& H, + FftData* S) { + RTC_DCHECK_GE(H.size(), H.size() - 1); + S->re.fill(0.f); + S->im.fill(0.f); + + rtc::ArrayView> render_buffer_data = + render_buffer.GetFftBuffer(); + const size_t num_render_channels = render_buffer_data[0].size(); + const size_t lim1 = std::min( + render_buffer_data.size() - render_buffer.Position(), num_partitions); + const size_t lim2 = num_partitions; + constexpr size_t kNumEightBinBands = kFftLengthBy2 / 8; + + size_t X_partition = render_buffer.Position(); + size_t p = 0; + size_t limit = lim1; + do { + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + const FftData& H_p_ch = H[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + for (size_t k = 0, n = 0; n < kNumEightBinBands; ++n, k += 8) { + const __m256 X_re = _mm256_loadu_ps(&X.re[k]); + const __m256 X_im = _mm256_loadu_ps(&X.im[k]); + const __m256 H_re = _mm256_loadu_ps(&H_p_ch.re[k]); + const __m256 H_im = _mm256_loadu_ps(&H_p_ch.im[k]); + const __m256 S_re = _mm256_loadu_ps(&S->re[k]); + const __m256 S_im = _mm256_loadu_ps(&S->im[k]); + const __m256 a = _mm256_mul_ps(X_re, H_re); + const __m256 b = _mm256_mul_ps(X_im, H_im); + const __m256 c = _mm256_mul_ps(X_re, H_im); + const __m256 d = _mm256_mul_ps(X_im, H_re); + const __m256 e = _mm256_sub_ps(a, b); + const __m256 f = _mm256_add_ps(c, d); + const __m256 g = _mm256_add_ps(S_re, e); + const __m256 h = _mm256_add_ps(S_im, f); + _mm256_storeu_ps(&S->re[k], g); + _mm256_storeu_ps(&S->im[k], h); + } + } + } + limit = lim2; + X_partition = 0; + } while (p < lim2); + + X_partition = render_buffer.Position(); + p 
= 0; + limit = lim1; + do { + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + const FftData& H_p_ch = H[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + S->re[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2] - + X.im[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2]; + S->im[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2] + + X.im[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2]; + } + } + limit = lim2; + X_partition = 0; + } while (p < lim2); +} + +} // namespace aec3 +} // namespace webrtc diff --git a/modules/audio_processing/aec3/adaptive_fir_filter_erl.cc b/modules/audio_processing/aec3/adaptive_fir_filter_erl.cc index 80378eb3cf..45b8813979 100644 --- a/modules/audio_processing/aec3/adaptive_fir_filter_erl.cc +++ b/modules/audio_processing/aec3/adaptive_fir_filter_erl.cc @@ -85,10 +85,12 @@ void ComputeErl(const Aec3Optimization& optimization, case Aec3Optimization::kSse2: aec3::ErlComputer_SSE2(H2, erl); break; + case Aec3Optimization::kAvx2: + aec3::ErlComputer_AVX2(H2, erl); + break; #endif #if defined(WEBRTC_HAS_NEON) case Aec3Optimization::kNeon: - aec3::ErlComputer_NEON(H2, erl); break; #endif diff --git a/modules/audio_processing/aec3/adaptive_fir_filter_erl.h b/modules/audio_processing/aec3/adaptive_fir_filter_erl.h index 108d9f8e44..4ac13b1bc3 100644 --- a/modules/audio_processing/aec3/adaptive_fir_filter_erl.h +++ b/modules/audio_processing/aec3/adaptive_fir_filter_erl.h @@ -36,6 +36,10 @@ void ErlComputer_NEON( void ErlComputer_SSE2( const std::vector>& H2, rtc::ArrayView erl); + +void ErlComputer_AVX2( + const std::vector>& H2, + rtc::ArrayView erl); #endif } // namespace aec3 diff --git a/modules/audio_processing/aec3/adaptive_fir_filter_erl_avx2.cc b/modules/audio_processing/aec3/adaptive_fir_filter_erl_avx2.cc new file mode 100644 index 0000000000..5fe7514db1 --- /dev/null +++ b/modules/audio_processing/aec3/adaptive_fir_filter_erl_avx2.cc @@ -0,0 +1,37 @@ 
+/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h" + +#include + +namespace webrtc { + +namespace aec3 { + +// Computes and stores the echo return loss estimate of the filter, which is the +// sum of the partition frequency responses. +void ErlComputer_AVX2( + const std::vector>& H2, + rtc::ArrayView erl) { + std::fill(erl.begin(), erl.end(), 0.f); + for (auto& H2_j : H2) { + for (size_t k = 0; k < kFftLengthBy2; k += 8) { + const __m256 H2_j_k = _mm256_loadu_ps(&H2_j[k]); + __m256 erl_k = _mm256_loadu_ps(&erl[k]); + erl_k = _mm256_add_ps(erl_k, H2_j_k); + _mm256_storeu_ps(&erl[k], erl_k); + } + erl[kFftLengthBy2] += H2_j[kFftLengthBy2]; + } +} + +} // namespace aec3 +} // namespace webrtc diff --git a/modules/audio_processing/aec3/adaptive_fir_filter_erl_unittest.cc b/modules/audio_processing/aec3/adaptive_fir_filter_erl_unittest.cc index 069fc9fa5b..fc30f7fd74 100644 --- a/modules/audio_processing/aec3/adaptive_fir_filter_erl_unittest.cc +++ b/modules/audio_processing/aec3/adaptive_fir_filter_erl_unittest.cc @@ -75,6 +75,31 @@ TEST(AdaptiveFirFilter, UpdateErlSse2Optimization) { } } +// Verifies that the optimized method for echo return loss computation is +// bitexact to the reference counterpart. 
+TEST(AdaptiveFirFilter, UpdateErlAvx2Optimization) { + bool use_avx2 = (WebRtc_GetCPUInfo(kAVX2) != 0); + if (use_avx2) { + const size_t kNumPartitions = 12; + std::vector> H2(kNumPartitions); + std::array erl; + std::array erl_AVX2; + + for (size_t j = 0; j < H2.size(); ++j) { + for (size_t k = 0; k < H2[j].size(); ++k) { + H2[j][k] = k + j / 3.f; + } + } + + ErlComputer(H2, erl); + ErlComputer_AVX2(H2, erl_AVX2); + + for (size_t j = 0; j < erl.size(); ++j) { + EXPECT_FLOAT_EQ(erl[j], erl_AVX2[j]); + } + } +} + #endif } // namespace aec3 diff --git a/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc b/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc index 39f4e11192..a18ebd86a1 100644 --- a/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc +++ b/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc @@ -246,6 +246,81 @@ TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels, } } +// Verifies that the optimized methods for filter adaptation are bitexact to +// their reference counterparts. 
+TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels, + FilterAdaptationAvx2Optimizations) { + const size_t num_render_channels = GetParam(); + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + + bool use_avx2 = (WebRtc_GetCPUInfo(kAVX2) != 0); + if (use_avx2) { + for (size_t num_partitions : {2, 5, 12, 30, 50}) { + std::unique_ptr render_delay_buffer( + RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz, + num_render_channels)); + Random random_generator(42U); + std::vector>> x( + kNumBands, + std::vector>(num_render_channels, + std::vector(kBlockSize, 0.f))); + FftData S_C; + FftData S_Avx2; + FftData G; + Aec3Fft fft; + std::vector> H_C( + num_partitions, std::vector(num_render_channels)); + std::vector> H_Avx2( + num_partitions, std::vector(num_render_channels)); + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + H_C[p][ch].Clear(); + H_Avx2[p][ch].Clear(); + } + } + + for (size_t k = 0; k < 500; ++k) { + for (size_t band = 0; band < x.size(); ++band) { + for (size_t ch = 0; ch < x[band].size(); ++ch) { + RandomizeSampleVector(&random_generator, x[band][ch]); + } + } + render_delay_buffer->Insert(x); + if (k == 0) { + render_delay_buffer->Reset(); + } + render_delay_buffer->PrepareCaptureProcessing(); + auto* const render_buffer = render_delay_buffer->GetRenderBuffer(); + + ApplyFilter_Avx2(*render_buffer, num_partitions, H_Avx2, &S_Avx2); + ApplyFilter(*render_buffer, num_partitions, H_C, &S_C); + for (size_t j = 0; j < S_C.re.size(); ++j) { + EXPECT_FLOAT_EQ(S_C.re[j], S_Avx2.re[j]); + EXPECT_FLOAT_EQ(S_C.im[j], S_Avx2.im[j]); + } + + std::for_each(G.re.begin(), G.re.end(), + [&](float& a) { a = random_generator.Rand(); }); + std::for_each(G.im.begin(), G.im.end(), + [&](float& a) { a = random_generator.Rand(); }); + + AdaptPartitions_Avx2(*render_buffer, G, num_partitions, &H_Avx2); + AdaptPartitions(*render_buffer, G, num_partitions, 
&H_C); + + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + for (size_t j = 0; j < H_C[p][ch].re.size(); ++j) { + EXPECT_FLOAT_EQ(H_C[p][ch].re[j], H_Avx2[p][ch].re[j]); + EXPECT_FLOAT_EQ(H_C[p][ch].im[j], H_Avx2[p][ch].im[j]); + } + } + } + } + } + } +} + // Verifies that the optimized method for frequency response computation is // bitexact to the reference counterpart. TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels, @@ -281,6 +356,41 @@ TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels, } } +// Verifies that the optimized method for frequency response computation is +// bitexact to the reference counterpart. +TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels, + ComputeFrequencyResponseAvx2Optimization) { + const size_t num_render_channels = GetParam(); + bool use_avx2 = (WebRtc_GetCPUInfo(kAVX2) != 0); + if (use_avx2) { + for (size_t num_partitions : {2, 5, 12, 30, 50}) { + std::vector> H( + num_partitions, std::vector(num_render_channels)); + std::vector> H2(num_partitions); + std::vector> H2_Avx2( + num_partitions); + + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + for (size_t k = 0; k < H[p][ch].re.size(); ++k) { + H[p][ch].re[k] = k + p / 3.f + ch; + H[p][ch].im[k] = p + k / 7.f - ch; + } + } + } + + ComputeFrequencyResponse(num_partitions, H, &H2); + ComputeFrequencyResponse_Avx2(num_partitions, H, &H2_Avx2); + + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t k = 0; k < H2[p].size(); ++k) { + EXPECT_FLOAT_EQ(H2[p][k], H2_Avx2[p][k]); + } + } + } + } +} + #endif #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) diff --git a/modules/audio_processing/aec3/aec3_common.cc b/modules/audio_processing/aec3/aec3_common.cc index aeb848a570..6aaab619fc 100644 --- a/modules/audio_processing/aec3/aec3_common.cc +++ b/modules/audio_processing/aec3/aec3_common.cc @@ -20,7 +20,9 @@ namespace webrtc { 
Aec3Optimization DetectOptimization() { #if defined(WEBRTC_ARCH_X86_FAMILY) - if (WebRtc_GetCPUInfo(kSSE2) != 0) { + if (WebRtc_GetCPUInfo(kAVX2) != 0) { + return Aec3Optimization::kAvx2; + } else if (WebRtc_GetCPUInfo(kSSE2) != 0) { return Aec3Optimization::kSse2; } #endif diff --git a/modules/audio_processing/aec3/aec3_common.h b/modules/audio_processing/aec3/aec3_common.h index cdeefc7046..a7e3121138 100644 --- a/modules/audio_processing/aec3/aec3_common.h +++ b/modules/audio_processing/aec3/aec3_common.h @@ -23,7 +23,7 @@ namespace webrtc { #define ALIGN16_END __attribute__((aligned(16))) #endif -enum class Aec3Optimization { kNone, kSse2, kNeon }; +enum class Aec3Optimization { kNone, kSse2, kAvx2, kNeon }; constexpr int kNumBlocksPerSecond = 250; diff --git a/modules/audio_processing/aec3/fft_data.h b/modules/audio_processing/aec3/fft_data.h index 5e5adb62de..9c25e784aa 100644 --- a/modules/audio_processing/aec3/fft_data.h +++ b/modules/audio_processing/aec3/fft_data.h @@ -40,6 +40,9 @@ struct FftData { im.fill(0.f); } + // Computes the power spectrum of the data. + void SpectrumAVX2(rtc::ArrayView power_spectrum) const; + // Computes the power spectrum of the data. void Spectrum(Aec3Optimization optimization, rtc::ArrayView power_spectrum) const { @@ -60,6 +63,9 @@ struct FftData { power_spectrum[kFftLengthBy2] = re[kFftLengthBy2] * re[kFftLengthBy2] + im[kFftLengthBy2] * im[kFftLengthBy2]; } break; + case Aec3Optimization::kAvx2: + SpectrumAVX2(power_spectrum); + break; #endif default: std::transform(re.begin(), re.end(), im.begin(), power_spectrum.begin(), diff --git a/modules/audio_processing/aec3/fft_data_avx2.cc b/modules/audio_processing/aec3/fft_data_avx2.cc new file mode 100644 index 0000000000..1fe4bd69c6 --- /dev/null +++ b/modules/audio_processing/aec3/fft_data_avx2.cc @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. 
+ * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/fft_data.h" + +#include + +#include "api/array_view.h" + +namespace webrtc { + +// Computes the power spectrum of the data. +void FftData::SpectrumAVX2(rtc::ArrayView power_spectrum) const { + RTC_DCHECK_EQ(kFftLengthBy2Plus1, power_spectrum.size()); + for (size_t k = 0; k < kFftLengthBy2; k += 8) { + __m256 r = _mm256_loadu_ps(&re[k]); + __m256 i = _mm256_loadu_ps(&im[k]); + __m256 ii = _mm256_mul_ps(i, i); + ii = _mm256_fmadd_ps(r, r, ii); + _mm256_storeu_ps(&power_spectrum[k], ii); + } + power_spectrum[kFftLengthBy2] = re[kFftLengthBy2] * re[kFftLengthBy2] + + im[kFftLengthBy2] * im[kFftLengthBy2]; +} + +} // namespace webrtc diff --git a/modules/audio_processing/aec3/fft_data_unittest.cc b/modules/audio_processing/aec3/fft_data_unittest.cc index 9be2680453..b0235a7fd0 100644 --- a/modules/audio_processing/aec3/fft_data_unittest.cc +++ b/modules/audio_processing/aec3/fft_data_unittest.cc @@ -19,7 +19,7 @@ namespace webrtc { #if defined(WEBRTC_ARCH_X86_FAMILY) // Verifies that the optimized methods are bitexact to their reference // counterparts. -TEST(FftData, TestOptimizations) { +TEST(FftData, TestSse2Optimizations) { if (WebRtc_GetCPUInfo(kSSE2) != 0) { FftData x; @@ -39,6 +39,29 @@ TEST(FftData, TestOptimizations) { EXPECT_EQ(spectrum, spectrum_sse2); } } + +// Verifies that the optimized methods are bitexact to their reference +// counterparts. 
+TEST(FftData, TestAvx2Optimizations) { + if (WebRtc_GetCPUInfo(kAVX2) != 0) { + FftData x; + + for (size_t k = 0; k < x.re.size(); ++k) { + x.re[k] = k + 1; + } + + x.im[0] = x.im[x.im.size() - 1] = 0.f; + for (size_t k = 1; k < x.im.size() - 1; ++k) { + x.im[k] = 2.f * (k + 1); + } + + std::array spectrum; + std::array spectrum_avx2; + x.Spectrum(Aec3Optimization::kNone, spectrum); + x.Spectrum(Aec3Optimization::kAvx2, spectrum_avx2); + EXPECT_EQ(spectrum, spectrum_avx2); + } +} #endif #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) diff --git a/modules/audio_processing/aec3/matched_filter.cc b/modules/audio_processing/aec3/matched_filter.cc index 2a489923b1..64b2d4e697 100644 --- a/modules/audio_processing/aec3/matched_filter.cc +++ b/modules/audio_processing/aec3/matched_filter.cc @@ -364,6 +364,11 @@ void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer, smoothing_, render_buffer.buffer, y, filters_[n], &filters_updated, &error_sum); break; + case Aec3Optimization::kAvx2: + aec3::MatchedFilterCore_AVX2(x_start_index, x2_sum_threshold, + smoothing_, render_buffer.buffer, y, + filters_[n], &filters_updated, &error_sum); + break; #endif #if defined(WEBRTC_HAS_NEON) case Aec3Optimization::kNeon: diff --git a/modules/audio_processing/aec3/matched_filter.h b/modules/audio_processing/aec3/matched_filter.h index df9245322f..954e4784f2 100644 --- a/modules/audio_processing/aec3/matched_filter.h +++ b/modules/audio_processing/aec3/matched_filter.h @@ -53,6 +53,16 @@ void MatchedFilterCore_SSE2(size_t x_start_index, bool* filters_updated, float* error_sum); +// Filter core for the matched filter that is optimized for AVX2. +void MatchedFilterCore_AVX2(size_t x_start_index, + float x2_sum_threshold, + float smoothing, + rtc::ArrayView x, + rtc::ArrayView y, + rtc::ArrayView h, + bool* filters_updated, + float* error_sum); + #endif // Filter core for the matched filter. 
diff --git a/modules/audio_processing/aec3/matched_filter_avx2.cc b/modules/audio_processing/aec3/matched_filter_avx2.cc new file mode 100644 index 0000000000..ed32102aa4 --- /dev/null +++ b/modules/audio_processing/aec3/matched_filter_avx2.cc @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/matched_filter.h" + +#include + +#include "rtc_base/checks.h" + +namespace webrtc { +namespace aec3 { + +void MatchedFilterCore_AVX2(size_t x_start_index, + float x2_sum_threshold, + float smoothing, + rtc::ArrayView x, + rtc::ArrayView y, + rtc::ArrayView h, + bool* filters_updated, + float* error_sum) { + const int h_size = static_cast(h.size()); + const int x_size = static_cast(x.size()); + RTC_DCHECK_EQ(0, h_size % 8); + + // Process for all samples in the sub-block. + for (size_t i = 0; i < y.size(); ++i) { + // Apply the matched filter as filter * x, and compute x * x. + + RTC_DCHECK_GT(x_size, x_start_index); + const float* x_p = &x[x_start_index]; + const float* h_p = &h[0]; + + // Initialize values for the accumulation. + __m256 s_256 = _mm256_set1_ps(0); + __m256 x2_sum_256 = _mm256_set1_ps(0); + float x2_sum = 0.f; + float s = 0; + + // Compute loop chunk sizes until, and after, the wraparound of the circular + // buffer for x. + const int chunk1 = + std::min(h_size, static_cast(x_size - x_start_index)); + + // Perform the loop in two chunks. + const int chunk2 = h_size - chunk1; + for (int limit : {chunk1, chunk2}) { + // Perform 256 bit vector operations. 
+ const int limit_by_8 = limit >> 3; + for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) { + // Load the data into 256 bit vectors. + __m256 x_k = _mm256_loadu_ps(x_p); + __m256 h_k = _mm256_loadu_ps(h_p); + // Compute and accumulate x * x and h * x. + x2_sum_256 = _mm256_fmadd_ps(x_k, x_k, x2_sum_256); + s_256 = _mm256_fmadd_ps(h_k, x_k, s_256); + } + + // Perform non-vector operations for any remaining items. + for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) { + const float x_k = *x_p; + x2_sum += x_k * x_k; + s += *h_p * x_k; + } + + x_p = &x[0]; + } + + // Sum components together. + __m128 x2_sum_128 = _mm_add_ps(_mm256_extractf128_ps(x2_sum_256, 0), + _mm256_extractf128_ps(x2_sum_256, 1)); + __m128 s_128 = _mm_add_ps(_mm256_extractf128_ps(s_256, 0), + _mm256_extractf128_ps(s_256, 1)); + // Combine the accumulated vector and scalar values. + float* v = reinterpret_cast(&x2_sum_128); + x2_sum += v[0] + v[1] + v[2] + v[3]; + v = reinterpret_cast(&s_128); + s += v[0] + v[1] + v[2] + v[3]; + + // Compute the matched filter error. + float e = y[i] - s; + const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f; + (*error_sum) += e * e; + + // Update the matched filter estimate in an NLMS manner. + if (x2_sum > x2_sum_threshold && !saturation) { + RTC_DCHECK_LT(0.f, x2_sum); + const float alpha = smoothing * e / x2_sum; + const __m256 alpha_256 = _mm256_set1_ps(alpha); + + // filter = filter + smoothing * (y - filter * x) * x / x * x. + float* h_p = &h[0]; + x_p = &x[x_start_index]; + + // Perform the loop in two chunks. + for (int limit : {chunk1, chunk2}) { + // Perform 256 bit vector operations. + const int limit_by_8 = limit >> 3; + for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) { + // Load the data into 256 bit vectors. + __m256 h_k = _mm256_loadu_ps(h_p); + __m256 x_k = _mm256_loadu_ps(x_p); + // Compute h = h + alpha * x. + h_k = _mm256_fmadd_ps(x_k, alpha_256, h_k); + + // Store the result. 
+ _mm256_storeu_ps(h_p, h_k); + } + + // Perform non-vector operations for any remaining items. + for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) { + *h_p += alpha * *x_p; + } + + x_p = &x[0]; + } + + *filters_updated = true; + } + + x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1; + } +} + +} // namespace aec3 +} // namespace webrtc diff --git a/modules/audio_processing/aec3/matched_filter_unittest.cc b/modules/audio_processing/aec3/matched_filter_unittest.cc index 7d9a7d4d0a..7e16c01ed9 100644 --- a/modules/audio_processing/aec3/matched_filter_unittest.cc +++ b/modules/audio_processing/aec3/matched_filter_unittest.cc @@ -133,6 +133,47 @@ TEST(MatchedFilter, TestSse2Optimizations) { } } +TEST(MatchedFilter, TestAvx2Optimizations) { + bool use_avx2 = (WebRtc_GetCPUInfo(kAVX2) != 0); + if (use_avx2) { + Random random_generator(42U); + constexpr float kSmoothing = 0.7f; + for (auto down_sampling_factor : kDownSamplingFactors) { + const size_t sub_block_size = kBlockSize / down_sampling_factor; + std::vector x(2000); + RandomizeSampleVector(&random_generator, x); + std::vector y(sub_block_size); + std::vector h_AVX2(512); + std::vector h(512); + int x_index = 0; + for (int k = 0; k < 1000; ++k) { + RandomizeSampleVector(&random_generator, y); + + bool filters_updated = false; + float error_sum = 0.f; + bool filters_updated_AVX2 = false; + float error_sum_AVX2 = 0.f; + + MatchedFilterCore_AVX2(x_index, h.size() * 150.f * 150.f, kSmoothing, x, + y, h_AVX2, &filters_updated_AVX2, + &error_sum_AVX2); + + MatchedFilterCore(x_index, h.size() * 150.f * 150.f, kSmoothing, x, y, + h, &filters_updated, &error_sum); + + EXPECT_EQ(filters_updated, filters_updated_AVX2); + EXPECT_NEAR(error_sum, error_sum_AVX2, error_sum / 100000.f); + + for (size_t j = 0; j < h.size(); ++j) { + EXPECT_NEAR(h[j], h_AVX2[j], 0.00001f); + } + + x_index = (x_index + sub_block_size) % x.size(); + } + } + } +} + #endif // Verifies that the matched filter produces 
proper lag estimates for diff --git a/modules/audio_processing/aec3/vector_math.h b/modules/audio_processing/aec3/vector_math.h index 883cd95fdd..e4d1381ae1 100644 --- a/modules/audio_processing/aec3/vector_math.h +++ b/modules/audio_processing/aec3/vector_math.h @@ -40,6 +40,7 @@ class VectorMath { : optimization_(optimization) {} // Elementwise square root. + void SqrtAVX2(rtc::ArrayView x); void Sqrt(rtc::ArrayView x) { switch (optimization_) { #if defined(WEBRTC_ARCH_X86_FAMILY) @@ -58,6 +59,9 @@ class VectorMath { x[j] = sqrtf(x[j]); } } break; + case Aec3Optimization::kAvx2: + SqrtAVX2(x); + break; #endif #if defined(WEBRTC_HAS_NEON) case Aec3Optimization::kNeon: { @@ -110,6 +114,9 @@ class VectorMath { } // Elementwise vector multiplication z = x * y. + void MultiplyAVX2(rtc::ArrayView x, + rtc::ArrayView y, + rtc::ArrayView z); void Multiply(rtc::ArrayView x, rtc::ArrayView y, rtc::ArrayView z) { @@ -133,6 +140,9 @@ class VectorMath { z[j] = x[j] * y[j]; } } break; + case Aec3Optimization::kAvx2: + MultiplyAVX2(x, y, z); + break; #endif #if defined(WEBRTC_HAS_NEON) case Aec3Optimization::kNeon: { @@ -159,6 +169,7 @@ class VectorMath { } // Elementwise vector accumulation z += x. + void AccumulateAVX2(rtc::ArrayView x, rtc::ArrayView z); void Accumulate(rtc::ArrayView x, rtc::ArrayView z) { RTC_DCHECK_EQ(z.size(), x.size()); switch (optimization_) { @@ -179,6 +190,9 @@ class VectorMath { z[j] += x[j]; } } break; + case Aec3Optimization::kAvx2: + AccumulateAVX2(x, z); + break; #endif #if defined(WEBRTC_HAS_NEON) case Aec3Optimization::kNeon: { diff --git a/modules/audio_processing/aec3/vector_math_avx2.cc b/modules/audio_processing/aec3/vector_math_avx2.cc new file mode 100644 index 0000000000..0b5f3c142e --- /dev/null +++ b/modules/audio_processing/aec3/vector_math_avx2.cc @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. 
+ * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/vector_math.h" + +#include +#include + +#include "api/array_view.h" +#include "rtc_base/checks.h" + +namespace webrtc { +namespace aec3 { + +// Elementwise square root. +void VectorMath::SqrtAVX2(rtc::ArrayView x) { + const int x_size = static_cast(x.size()); + const int vector_limit = x_size >> 3; + + int j = 0; + for (; j < vector_limit * 8; j += 8) { + __m256 g = _mm256_loadu_ps(&x[j]); + g = _mm256_sqrt_ps(g); + _mm256_storeu_ps(&x[j], g); + } + + for (; j < x_size; ++j) { + x[j] = sqrtf(x[j]); + } +} + +// Elementwise vector multiplication z = x * y. +void VectorMath::MultiplyAVX2(rtc::ArrayView x, + rtc::ArrayView y, + rtc::ArrayView z) { + RTC_DCHECK_EQ(z.size(), x.size()); + RTC_DCHECK_EQ(z.size(), y.size()); + const int x_size = static_cast(x.size()); + const int vector_limit = x_size >> 3; + + int j = 0; + for (; j < vector_limit * 8; j += 8) { + const __m256 x_j = _mm256_loadu_ps(&x[j]); + const __m256 y_j = _mm256_loadu_ps(&y[j]); + const __m256 z_j = _mm256_mul_ps(x_j, y_j); + _mm256_storeu_ps(&z[j], z_j); + } + + for (; j < x_size; ++j) { + z[j] = x[j] * y[j]; + } +} + +// Elementwise vector accumulation z += x. 
+void VectorMath::AccumulateAVX2(rtc::ArrayView x, + rtc::ArrayView z) { + RTC_DCHECK_EQ(z.size(), x.size()); + const int x_size = static_cast(x.size()); + const int vector_limit = x_size >> 3; + + int j = 0; + for (; j < vector_limit * 8; j += 8) { + const __m256 x_j = _mm256_loadu_ps(&x[j]); + __m256 z_j = _mm256_loadu_ps(&z[j]); + z_j = _mm256_add_ps(x_j, z_j); + _mm256_storeu_ps(&z[j], z_j); + } + + for (; j < x_size; ++j) { + z[j] += x[j]; + } +} + +} // namespace aec3 +} // namespace webrtc diff --git a/modules/audio_processing/aec3/vector_math_unittest.cc b/modules/audio_processing/aec3/vector_math_unittest.cc index fdab2e52ca..bd156b579c 100644 --- a/modules/audio_processing/aec3/vector_math_unittest.cc +++ b/modules/audio_processing/aec3/vector_math_unittest.cc @@ -79,7 +79,7 @@ TEST(VectorMath, Accumulate) { #if defined(WEBRTC_ARCH_X86_FAMILY) -TEST(VectorMath, Sqrt) { +TEST(VectorMath, Sse2Sqrt) { if (WebRtc_GetCPUInfo(kSSE2) != 0) { std::array x; std::array z; @@ -101,7 +101,29 @@ TEST(VectorMath, Sqrt) { } } -TEST(VectorMath, Multiply) { +TEST(VectorMath, Avx2Sqrt) { + if (WebRtc_GetCPUInfo(kAVX2) != 0) { + std::array x; + std::array z; + std::array z_avx2; + + for (size_t k = 0; k < x.size(); ++k) { + x[k] = (2.f / 3.f) * k; + } + + std::copy(x.begin(), x.end(), z.begin()); + aec3::VectorMath(Aec3Optimization::kNone).Sqrt(z); + std::copy(x.begin(), x.end(), z_avx2.begin()); + aec3::VectorMath(Aec3Optimization::kAvx2).Sqrt(z_avx2); + EXPECT_EQ(z, z_avx2); + for (size_t k = 0; k < z.size(); ++k) { + EXPECT_FLOAT_EQ(z[k], z_avx2[k]); + EXPECT_FLOAT_EQ(sqrtf(x[k]), z_avx2[k]); + } + } +} + +TEST(VectorMath, Sse2Multiply) { if (WebRtc_GetCPUInfo(kSSE2) != 0) { std::array x; std::array y; @@ -122,7 +144,28 @@ TEST(VectorMath, Multiply) { } } -TEST(VectorMath, Accumulate) { +TEST(VectorMath, Avx2Multiply) { + if (WebRtc_GetCPUInfo(kAVX2) != 0) { + std::array x; + std::array y; + std::array z; + std::array z_avx2; + + for (size_t k = 0; k < x.size(); ++k) { + 
x[k] = k; + y[k] = (2.f / 3.f) * k; + } + + aec3::VectorMath(Aec3Optimization::kNone).Multiply(x, y, z); + aec3::VectorMath(Aec3Optimization::kAvx2).Multiply(x, y, z_avx2); + for (size_t k = 0; k < z.size(); ++k) { + EXPECT_FLOAT_EQ(z[k], z_avx2[k]); + EXPECT_FLOAT_EQ(x[k] * y[k], z_avx2[k]); + } + } +} + +TEST(VectorMath, Sse2Accumulate) { if (WebRtc_GetCPUInfo(kSSE2) != 0) { std::array x; std::array z; @@ -141,6 +184,26 @@ TEST(VectorMath, Accumulate) { } } } + +TEST(VectorMath, Avx2Accumulate) { + if (WebRtc_GetCPUInfo(kAVX2) != 0) { + std::array x; + std::array z; + std::array z_avx2; + + for (size_t k = 0; k < x.size(); ++k) { + x[k] = k; + z[k] = z_avx2[k] = 2.f * k; + } + + aec3::VectorMath(Aec3Optimization::kNone).Accumulate(x, z); + aec3::VectorMath(Aec3Optimization::kAvx2).Accumulate(x, z_avx2); + for (size_t k = 0; k < z.size(); ++k) { + EXPECT_FLOAT_EQ(z[k], z_avx2[k]); + EXPECT_FLOAT_EQ(x[k] + 2.f * x[k], z_avx2[k]); + } + } +} #endif } // namespace webrtc