From f246b91eba0e8d95bd3fee4634887fb6d3017811 Mon Sep 17 00:00:00 2001 From: peah Date: Wed, 3 May 2017 06:28:59 -0700 Subject: [PATCH] Added ARM Neon optimizations for AEC3 This CL adds Neon SIMD optimizations for AEC3 on ARM, resulting in an 8 times complexity reduction. The optimizations are basically identical to what was already in place for SSE2. BUG=chromium:14993, webrtc:6018 Review-Url: https://codereview.webrtc.org/2834073005 Cr-Commit-Position: refs/heads/master@{#17993} --- .../aec3/adaptive_fir_filter.cc | 174 ++++++++++++++++++ .../aec3/adaptive_fir_filter.h | 20 ++ .../aec3/adaptive_fir_filter_unittest.cc | 111 ++++++++++- .../audio_processing/aec3/aec3_common.cc | 5 + .../audio_processing/aec3/aec3_common.h | 2 +- .../audio_processing/aec3/matched_filter.cc | 118 ++++++++++++ .../audio_processing/aec3/matched_filter.h | 13 ++ .../aec3/matched_filter_unittest.cc | 43 ++++- .../audio_processing/aec3/vector_math.h | 84 +++++++++ .../aec3/vector_math_unittest.cc | 59 ++++++ 10 files changed, 622 insertions(+), 7 deletions(-) diff --git a/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.cc b/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.cc index 3174fa762e..7c29558c7e 100644 --- a/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.cc +++ b/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.cc @@ -10,6 +10,9 @@ #include "webrtc/modules/audio_processing/aec3/adaptive_fir_filter.h" +#if defined(WEBRTC_HAS_NEON) +#include +#endif #include "webrtc/typedefs.h" #if defined(WEBRTC_ARCH_X86_FAMILY) #include @@ -52,6 +55,26 @@ void UpdateFrequencyResponse( } } +#if defined(WEBRTC_HAS_NEON) +// Computes and stores the frequency response of the filter. 
+void UpdateFrequencyResponse_NEON( + rtc::ArrayView H, + std::vector>* H2) { + RTC_DCHECK_EQ(H.size(), H2->size()); + for (size_t k = 0; k < H.size(); ++k) { + for (size_t j = 0; j < kFftLengthBy2; j += 4) { + const float32x4_t re = vld1q_f32(&H[k].re[j]); + const float32x4_t im = vld1q_f32(&H[k].im[j]); + float32x4_t H2_k_j = vmulq_f32(re, re); + H2_k_j = vmlaq_f32(H2_k_j, im, im); + vst1q_f32(&(*H2)[k][j], H2_k_j); + } + (*H2)[k][kFftLengthBy2] = H[k].re[kFftLengthBy2] * H[k].re[kFftLengthBy2] + + H[k].im[kFftLengthBy2] * H[k].im[kFftLengthBy2]; + } +} +#endif + #if defined(WEBRTC_ARCH_X86_FAMILY) // Computes and stores the frequency response of the filter. void UpdateFrequencyResponse_SSE2( @@ -85,6 +108,25 @@ void UpdateErlEstimator( } } +#if defined(WEBRTC_HAS_NEON) +// Computes and stores the echo return loss estimate of the filter, which is the +// sum of the partition frequency responses. +void UpdateErlEstimator_NEON( + const std::vector>& H2, + std::array* erl) { + erl->fill(0.f); + for (auto& H2_j : H2) { + for (size_t k = 0; k < kFftLengthBy2; k += 4) { + const float32x4_t H2_j_k = vld1q_f32(&H2_j[k]); + float32x4_t erl_k = vld1q_f32(&(*erl)[k]); + erl_k = vaddq_f32(erl_k, H2_j_k); + vst1q_f32(&(*erl)[k], erl_k); + } + (*erl)[kFftLengthBy2] += H2_j[kFftLengthBy2]; + } +} +#endif + #if defined(WEBRTC_ARCH_X86_FAMILY) // Computes and stores the echo return loss estimate of the filter, which is the // sum of the partition frequency responses. @@ -121,6 +163,63 @@ void AdaptPartitions(const RenderBuffer& render_buffer, } } +#if defined(WEBRTC_HAS_NEON) +// Adapts the filter partitions. 
(NEON variant) +void AdaptPartitions_NEON(const RenderBuffer& render_buffer, + const FftData& G, + rtc::ArrayView H) { + rtc::ArrayView render_buffer_data = render_buffer.Buffer(); + const int lim1 = + std::min(render_buffer_data.size() - render_buffer.Position(), H.size()); + const int lim2 = H.size(); + constexpr int kNumFourBinBands = kFftLengthBy2 / 4; + FftData* H_j = &H[0]; + const FftData* X = &render_buffer_data[render_buffer.Position()]; + int limit = lim1; + int j = 0; + do { + for (; j < limit; ++j, ++H_j, ++X) { + for (int k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) { + const float32x4_t G_re = vld1q_f32(&G.re[k]); + const float32x4_t G_im = vld1q_f32(&G.im[k]); + const float32x4_t X_re = vld1q_f32(&X->re[k]); + const float32x4_t X_im = vld1q_f32(&X->im[k]); + const float32x4_t H_re = vld1q_f32(&H_j->re[k]); + const float32x4_t H_im = vld1q_f32(&H_j->im[k]); + const float32x4_t a = vmulq_f32(X_re, G_re); + const float32x4_t e = vmlaq_f32(a, X_im, G_im); + const float32x4_t c = vmulq_f32(X_re, G_im); + const float32x4_t f = vmlsq_f32(c, X_im, G_re); + const float32x4_t g = vaddq_f32(H_re, e); + const float32x4_t h = vaddq_f32(H_im, f); + + vst1q_f32(&H_j->re[k], g); + vst1q_f32(&H_j->im[k], h); + } + } + + X = &render_buffer_data[0]; + limit = lim2; + } while (j < lim2); + + H_j = &H[0]; + X = &render_buffer_data[render_buffer.Position()]; + limit = lim1; + j = 0; + do { + for (; j < limit; ++j, ++H_j, ++X) { + H_j->re[kFftLengthBy2] += X->re[kFftLengthBy2] * G.re[kFftLengthBy2] + + X->im[kFftLengthBy2] * G.im[kFftLengthBy2]; + H_j->im[kFftLengthBy2] += X->re[kFftLengthBy2] * G.im[kFftLengthBy2] - + X->im[kFftLengthBy2] * G.re[kFftLengthBy2]; + } + + X = &render_buffer_data[0]; + limit = lim2; + } while (j < lim2); +} +#endif + #if defined(WEBRTC_ARCH_X86_FAMILY) // Adapts the filter partitions. 
(SSE2 variant) void AdaptPartitions_SSE2(const RenderBuffer& render_buffer, @@ -203,6 +302,65 @@ void ApplyFilter(const RenderBuffer& render_buffer, } } +#if defined(WEBRTC_HAS_NEON) +// Produces the filter output (NEON variant). +void ApplyFilter_NEON(const RenderBuffer& render_buffer, + rtc::ArrayView H, + FftData* S) { + RTC_DCHECK_GE(H.size(), H.size() - 1); + S->re.fill(0.f); + S->im.fill(0.f); + + rtc::ArrayView render_buffer_data = render_buffer.Buffer(); + const int lim1 = + std::min(render_buffer_data.size() - render_buffer.Position(), H.size()); + const int lim2 = H.size(); + constexpr int kNumFourBinBands = kFftLengthBy2 / 4; + const FftData* H_j = &H[0]; + const FftData* X = &render_buffer_data[render_buffer.Position()]; + + int j = 0; + int limit = lim1; + do { + for (; j < limit; ++j, ++H_j, ++X) { + for (int k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) { + const float32x4_t X_re = vld1q_f32(&X->re[k]); + const float32x4_t X_im = vld1q_f32(&X->im[k]); + const float32x4_t H_re = vld1q_f32(&H_j->re[k]); + const float32x4_t H_im = vld1q_f32(&H_j->im[k]); + const float32x4_t S_re = vld1q_f32(&S->re[k]); + const float32x4_t S_im = vld1q_f32(&S->im[k]); + const float32x4_t a = vmulq_f32(X_re, H_re); + const float32x4_t e = vmlsq_f32(a, X_im, H_im); + const float32x4_t c = vmulq_f32(X_re, H_im); + const float32x4_t f = vmlaq_f32(c, X_im, H_re); + const float32x4_t g = vaddq_f32(S_re, e); + const float32x4_t h = vaddq_f32(S_im, f); + vst1q_f32(&S->re[k], g); + vst1q_f32(&S->im[k], h); + } + } + limit = lim2; + X = &render_buffer_data[0]; + } while (j < lim2); + + H_j = &H[0]; + X = &render_buffer_data[render_buffer.Position()]; + j = 0; + limit = lim1; + do { + for (; j < limit; ++j, ++H_j, ++X) { + S->re[kFftLengthBy2] += X->re[kFftLengthBy2] * H_j->re[kFftLengthBy2] - + X->im[kFftLengthBy2] * H_j->im[kFftLengthBy2]; + S->im[kFftLengthBy2] += X->re[kFftLengthBy2] * H_j->im[kFftLengthBy2] + + X->im[kFftLengthBy2] * H_j->re[kFftLengthBy2]; + } + limit = 
lim2; + X = &render_buffer_data[0]; + } while (j < lim2); +} +#endif + #if defined(WEBRTC_ARCH_X86_FAMILY) // Produces the filter output (SSE2 variant). void ApplyFilter_SSE2(const RenderBuffer& render_buffer, @@ -305,6 +463,11 @@ void AdaptiveFirFilter::Filter(const RenderBuffer& render_buffer, case Aec3Optimization::kSse2: aec3::ApplyFilter_SSE2(render_buffer, H_, S); break; +#endif +#if defined(WEBRTC_HAS_NEON) + case Aec3Optimization::kNeon: + aec3::ApplyFilter_NEON(render_buffer, H_, S); + break; #endif default: aec3::ApplyFilter(render_buffer, H_, S); @@ -319,6 +482,11 @@ void AdaptiveFirFilter::Adapt(const RenderBuffer& render_buffer, case Aec3Optimization::kSse2: aec3::AdaptPartitions_SSE2(render_buffer, G, H_); break; +#endif +#if defined(WEBRTC_HAS_NEON) + case Aec3Optimization::kNeon: + aec3::AdaptPartitions_NEON(render_buffer, G, H_); + break; #endif default: aec3::AdaptPartitions(render_buffer, G, H_); @@ -337,6 +505,12 @@ void AdaptiveFirFilter::Adapt(const RenderBuffer& render_buffer, aec3::UpdateFrequencyResponse_SSE2(H_, &H2_); aec3::UpdateErlEstimator_SSE2(H2_, &erl_); break; +#endif +#if defined(WEBRTC_HAS_NEON) + case Aec3Optimization::kNeon: + aec3::UpdateFrequencyResponse_NEON(H_, &H2_); + aec3::UpdateErlEstimator_NEON(H2_, &erl_); + break; #endif default: aec3::UpdateFrequencyResponse(H_, &H2_); diff --git a/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.h b/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.h index 78b64225bd..75c418bf0f 100644 --- a/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.h +++ b/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.h @@ -29,6 +29,11 @@ namespace aec3 { void UpdateFrequencyResponse( rtc::ArrayView H, std::vector>* H2); +#if defined(WEBRTC_HAS_NEON) +void UpdateFrequencyResponse_NEON( + rtc::ArrayView H, + std::vector>* H2); +#endif #if defined(WEBRTC_ARCH_X86_FAMILY) void UpdateFrequencyResponse_SSE2( rtc::ArrayView H, @@ -40,6 +45,11 @@ void UpdateFrequencyResponse_SSE2( 
void UpdateErlEstimator( const std::vector>& H2, std::array* erl); +#if defined(WEBRTC_HAS_NEON) +void UpdateErlEstimator_NEON( + const std::vector>& H2, + std::array* erl); +#endif #if defined(WEBRTC_ARCH_X86_FAMILY) void UpdateErlEstimator_SSE2( const std::vector>& H2, @@ -50,6 +60,11 @@ void UpdateErlEstimator_SSE2( void AdaptPartitions(const RenderBuffer& render_buffer, const FftData& G, rtc::ArrayView H); +#if defined(WEBRTC_HAS_NEON) +void AdaptPartitions_NEON(const RenderBuffer& render_buffer, + const FftData& G, + rtc::ArrayView H); +#endif #if defined(WEBRTC_ARCH_X86_FAMILY) void AdaptPartitions_SSE2(const RenderBuffer& render_buffer, const FftData& G, @@ -60,6 +75,11 @@ void AdaptPartitions_SSE2(const RenderBuffer& render_buffer, void ApplyFilter(const RenderBuffer& render_buffer, rtc::ArrayView H, FftData* S); +#if defined(WEBRTC_HAS_NEON) +void ApplyFilter_NEON(const RenderBuffer& render_buffer, + rtc::ArrayView H, + FftData* S); +#endif #if defined(WEBRTC_ARCH_X86_FAMILY) void ApplyFilter_SSE2(const RenderBuffer& render_buffer, rtc::ArrayView H, diff --git a/webrtc/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc b/webrtc/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc index 6d1a5820fe..4560958bfd 100644 --- a/webrtc/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc @@ -10,6 +10,7 @@ #include "webrtc/modules/audio_processing/aec3/adaptive_fir_filter.h" +#include #include #include #include @@ -41,10 +42,114 @@ std::string ProduceDebugText(size_t delay) { } // namespace +#if defined(WEBRTC_HAS_NEON) +// Verifies that the optimized methods for filter adaptation are similar to +// their reference counterparts. 
+TEST(AdaptiveFirFilter, FilterAdaptationNeonOptimizations) { + RenderBuffer render_buffer(Aec3Optimization::kNone, 3, 12, + std::vector(1, 12)); + Random random_generator(42U); + std::vector> x(3, std::vector(kBlockSize, 0.f)); + FftData S_C; + FftData S_NEON; + FftData G; + Aec3Fft fft; + std::vector H_C(10); + std::vector H_NEON(10); + for (auto& H_j : H_C) { + H_j.Clear(); + } + for (auto& H_j : H_NEON) { + H_j.Clear(); + } + + for (size_t k = 0; k < 30; ++k) { + RandomizeSampleVector(&random_generator, x[0]); + render_buffer.Insert(x); + } + + for (size_t j = 0; j < G.re.size(); ++j) { + G.re[j] = j / 10001.f; + } + for (size_t j = 1; j < G.im.size() - 1; ++j) { + G.im[j] = j / 20001.f; + } + G.im[0] = 0.f; + G.im[G.im.size() - 1] = 0.f; + + AdaptPartitions_NEON(render_buffer, G, H_NEON); + AdaptPartitions(render_buffer, G, H_C); + AdaptPartitions_NEON(render_buffer, G, H_NEON); + AdaptPartitions(render_buffer, G, H_C); + + for (size_t l = 0; l < H_C.size(); ++l) { + for (size_t j = 0; j < H_C[l].im.size(); ++j) { + EXPECT_NEAR(H_C[l].re[j], H_NEON[l].re[j], fabs(H_C[l].re[j] * 0.00001f)); + EXPECT_NEAR(H_C[l].im[j], H_NEON[l].im[j], fabs(H_C[l].im[j] * 0.00001f)); + } + } + + ApplyFilter_NEON(render_buffer, H_NEON, &S_NEON); + ApplyFilter(render_buffer, H_C, &S_C); + for (size_t j = 0; j < S_C.re.size(); ++j) { + EXPECT_NEAR(S_C.re[j], S_NEON.re[j], fabs(S_C.re[j] * 0.00001f)); + EXPECT_NEAR(S_C.im[j], S_NEON.im[j], fabs(S_C.re[j] * 0.00001f)); + } +} + +// Verifies that the optimized method for frequency response computation is +// bitexact to the reference counterpart. 
+TEST(AdaptiveFirFilter, UpdateFrequencyResponseNeonOptimization) { + const size_t kNumPartitions = 12; + std::vector H(kNumPartitions); + std::vector> H2(kNumPartitions); + std::vector> H2_NEON(kNumPartitions); + + for (size_t j = 0; j < H.size(); ++j) { + for (size_t k = 0; k < H[j].re.size(); ++k) { + H[j].re[k] = k + j / 3.f; + H[j].im[k] = j + k / 7.f; + } + } + + UpdateFrequencyResponse(H, &H2); + UpdateFrequencyResponse_NEON(H, &H2_NEON); + + for (size_t j = 0; j < H2.size(); ++j) { + for (size_t k = 0; k < H[j].re.size(); ++k) { + EXPECT_FLOAT_EQ(H2[j][k], H2_NEON[j][k]); + } + } +} + +// Verifies that the optimized method for echo return loss computation is +// bitexact to the reference counterpart. +TEST(AdaptiveFirFilter, UpdateErlNeonOptimization) { + const size_t kNumPartitions = 12; + std::vector> H2(kNumPartitions); + std::array erl; + std::array erl_NEON; + + for (size_t j = 0; j < H2.size(); ++j) { + for (size_t k = 0; k < H2[j].size(); ++k) { + H2[j][k] = k + j / 3.f; + } + } + + UpdateErlEstimator(H2, &erl); + UpdateErlEstimator_NEON(H2, &erl_NEON); + + for (size_t j = 0; j < erl.size(); ++j) { + EXPECT_FLOAT_EQ(erl[j], erl_NEON[j]); + } +} + +#endif + #if defined(WEBRTC_ARCH_X86_FAMILY) // Verifies that the optimized methods for filter adaptation are bitexact to // their reference counterparts. -TEST(AdaptiveFirFilter, FilterAdaptationOptimizations) { +TEST(AdaptiveFirFilter, FilterAdaptationSse2Optimizations) { bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0); if (use_sse2) { RenderBuffer render_buffer(Aec3Optimization::kNone, 3, 12, @@ -95,7 +200,7 @@ TEST(AdaptiveFirFilter, FilterAdaptationOptimizations) { // Verifies that the optimized method for frequency response computation is // bitexact to the reference counterpart. 
-TEST(AdaptiveFirFilter, UpdateFrequencyResponseOptimization) { +TEST(AdaptiveFirFilter, UpdateFrequencyResponseSse2Optimization) { bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0); if (use_sse2) { const size_t kNumPartitions = 12; @@ -123,7 +228,7 @@ TEST(AdaptiveFirFilter, UpdateFrequencyResponseOptimization) { // Verifies that the optimized method for echo return loss computation is // bitexact to the reference counterpart. -TEST(AdaptiveFirFilter, UpdateErlOptimization) { +TEST(AdaptiveFirFilter, UpdateErlSse2Optimization) { bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0); if (use_sse2) { const size_t kNumPartitions = 12; diff --git a/webrtc/modules/audio_processing/aec3/aec3_common.cc b/webrtc/modules/audio_processing/aec3/aec3_common.cc index da0f2c4f19..ae0624703b 100644 --- a/webrtc/modules/audio_processing/aec3/aec3_common.cc +++ b/webrtc/modules/audio_processing/aec3/aec3_common.cc @@ -21,6 +21,11 @@ Aec3Optimization DetectOptimization() { return Aec3Optimization::kSse2; } #endif + +#if defined(WEBRTC_HAS_NEON) + return Aec3Optimization::kNeon; +#endif + return Aec3Optimization::kNone; } diff --git a/webrtc/modules/audio_processing/aec3/aec3_common.h b/webrtc/modules/audio_processing/aec3/aec3_common.h index dbfc4ed0d6..3ea26c236f 100644 --- a/webrtc/modules/audio_processing/aec3/aec3_common.h +++ b/webrtc/modules/audio_processing/aec3/aec3_common.h @@ -24,7 +24,7 @@ namespace webrtc { #define ALIGN16_END __attribute__((aligned(16))) #endif -enum class Aec3Optimization { kNone, kSse2 }; +enum class Aec3Optimization { kNone, kSse2, kNeon }; constexpr int kNumBlocksPerSecond = 250; diff --git a/webrtc/modules/audio_processing/aec3/matched_filter.cc b/webrtc/modules/audio_processing/aec3/matched_filter.cc index 7bb5778999..4c6e0d7052 100644 --- a/webrtc/modules/audio_processing/aec3/matched_filter.cc +++ b/webrtc/modules/audio_processing/aec3/matched_filter.cc @@ -9,6 +9,9 @@ */ #include "webrtc/modules/audio_processing/aec3/matched_filter.h" +#if 
defined(WEBRTC_HAS_NEON) +#include +#endif #include "webrtc/typedefs.h" #if defined(WEBRTC_ARCH_X86_FAMILY) #include @@ -22,6 +25,114 @@ namespace webrtc { namespace aec3 { +#if defined(WEBRTC_HAS_NEON) + +void MatchedFilterCore_NEON(size_t x_start_index, + float x2_sum_threshold, + rtc::ArrayView x, + rtc::ArrayView y, + rtc::ArrayView h, + bool* filters_updated, + float* error_sum) { + const int h_size = static_cast(h.size()); + const int x_size = static_cast(x.size()); + RTC_DCHECK_EQ(0, h_size % 4); + + // Process for all samples in the sub-block. + for (size_t i = 0; i < kSubBlockSize; ++i) { + // Apply the matched filter as filter * x, and compute x * x. + + RTC_DCHECK_GT(x_size, x_start_index); + const float* x_p = &x[x_start_index]; + const float* h_p = &h[0]; + + // Initialize values for the accumulation. + float32x4_t s_128 = vdupq_n_f32(0); + float32x4_t x2_sum_128 = vdupq_n_f32(0); + float x2_sum = 0.f; + float s = 0; + + // Compute loop chunk sizes until, and after, the wraparound of the circular + // buffer for x. + const int chunk1 = + std::min(h_size, static_cast(x_size - x_start_index)); + + // Perform the loop in two chunks. + const int chunk2 = h_size - chunk1; + for (int limit : {chunk1, chunk2}) { + // Perform 128 bit vector operations. + const int limit_by_4 = limit >> 2; + for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) { + // Load the data into 128 bit vectors. + const float32x4_t x_k = vld1q_f32(x_p); + const float32x4_t h_k = vld1q_f32(h_p); + // Compute and accumulate x * x and h * x. + x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k); + s_128 = vmlaq_f32(s_128, h_k, x_k); + } + + // Perform non-vector operations for any remaining items. + for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) { + const float x_k = *x_p; + x2_sum += x_k * x_k; + s += *h_p * x_k; + } + + x_p = &x[0]; + } + + // Combine the accumulated vector and scalar values. 
+ float* v = reinterpret_cast(&x2_sum_128); + x2_sum += v[0] + v[1] + v[2] + v[3]; + v = reinterpret_cast(&s_128); + s += v[0] + v[1] + v[2] + v[3]; + + // Compute the matched filter error. + const float e = std::min(32767.f, std::max(-32768.f, y[i] - s)); + *error_sum += e * e; + + // Update the matched filter estimate in an NLMS manner. + if (x2_sum > x2_sum_threshold) { + RTC_DCHECK_LT(0.f, x2_sum); + const float alpha = 0.7f * e / x2_sum; + const float32x4_t alpha_128 = vmovq_n_f32(alpha); + + // filter = filter + 0.7 * (y - filter * x) * x / (x * x). + float* h_p = &h[0]; + x_p = &x[x_start_index]; + + // Perform the loop in two chunks. + for (int limit : {chunk1, chunk2}) { + // Perform 128 bit vector operations. + const int limit_by_4 = limit >> 2; + for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) { + // Load the data into 128 bit vectors. + float32x4_t h_k = vld1q_f32(h_p); + const float32x4_t x_k = vld1q_f32(x_p); + // Compute h = h + alpha * x. + h_k = vmlaq_f32(h_k, alpha_128, x_k); + + // Store the result. + vst1q_f32(h_p, h_k); + } + + // Perform non-vector operations for any remaining items. + for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) { + *h_p += alpha * *x_p; + } + + x_p = &x[0]; + } + + *filters_updated = true; + } + + x_start_index = x_start_index > 0 ? 
x_start_index - 1 : x_size - 1; + } +} + +#endif + #if defined(WEBRTC_ARCH_X86_FAMILY) void MatchedFilterCore_SSE2(size_t x_start_index, @@ -226,6 +337,13 @@ void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer, render_buffer.buffer, y, filters_[n], &filters_updated, &error_sum); break; +#endif +#if defined(WEBRTC_HAS_NEON) + case Aec3Optimization::kNeon: + aec3::MatchedFilterCore_NEON(x_start_index, x2_sum_threshold, + render_buffer.buffer, y, filters_[n], + &filters_updated, &error_sum); + break; #endif default: aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, diff --git a/webrtc/modules/audio_processing/aec3/matched_filter.h b/webrtc/modules/audio_processing/aec3/matched_filter.h index b9580c48f9..91df5c9d32 100644 --- a/webrtc/modules/audio_processing/aec3/matched_filter.h +++ b/webrtc/modules/audio_processing/aec3/matched_filter.h @@ -23,6 +23,19 @@ namespace webrtc { namespace aec3 { +#if defined(WEBRTC_HAS_NEON) + +// Filter core for the matched filter that is optimized for NEON. +void MatchedFilterCore_NEON(size_t x_start_index, + float x2_sum_threshold, + rtc::ArrayView x, + rtc::ArrayView y, + rtc::ArrayView h, + bool* filters_updated, + float* error_sum); + +#endif + #if defined(WEBRTC_ARCH_X86_FAMILY) // Filter core for the matched filter that is optimized for SSE2. diff --git a/webrtc/modules/audio_processing/aec3/matched_filter_unittest.cc b/webrtc/modules/audio_processing/aec3/matched_filter_unittest.cc index 45965c7bbb..02d91bb834 100644 --- a/webrtc/modules/audio_processing/aec3/matched_filter_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/matched_filter_unittest.cc @@ -43,10 +43,47 @@ constexpr size_t kNumMatchedFilters = 4; } // namespace -#if defined(WEBRTC_ARCH_X86_FAMILY) -// Verifies that the optimized methods are bitexact to their reference +#if defined(WEBRTC_HAS_NEON) +// Verifies that the optimized methods for NEON are similar to their reference // counterparts. 
-TEST(MatchedFilter, TestOptimizations) { +TEST(MatchedFilter, TestNeonOptimizations) { + Random random_generator(42U); + std::vector x(2000); + RandomizeSampleVector(&random_generator, x); + std::vector y(kSubBlockSize); + std::vector h_NEON(512); + std::vector h(512); + int x_index = 0; + for (int k = 0; k < 1000; ++k) { + RandomizeSampleVector(&random_generator, y); + + bool filters_updated = false; + float error_sum = 0.f; + bool filters_updated_NEON = false; + float error_sum_NEON = 0.f; + + MatchedFilterCore_NEON(x_index, h.size() * 150.f * 150.f, x, y, h_NEON, + &filters_updated_NEON, &error_sum_NEON); + + MatchedFilterCore(x_index, h.size() * 150.f * 150.f, x, y, h, + &filters_updated, &error_sum); + + EXPECT_EQ(filters_updated, filters_updated_NEON); + EXPECT_NEAR(error_sum, error_sum_NEON, error_sum / 100000.f); + + for (size_t j = 0; j < h.size(); ++j) { + EXPECT_NEAR(h[j], h_NEON[j], 0.00001f); + } + + x_index = (x_index + kSubBlockSize) % x.size(); + } +} +#endif + +#if defined(WEBRTC_ARCH_X86_FAMILY) +// Verifies that the optimized methods for SSE2 are bitexact to their reference +// counterparts. 
+TEST(MatchedFilter, TestSse2Optimizations) { bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0); if (use_sse2) { Random random_generator(42U); diff --git a/webrtc/modules/audio_processing/aec3/vector_math.h b/webrtc/modules/audio_processing/aec3/vector_math.h index afd4262b6b..b943f2bedd 100644 --- a/webrtc/modules/audio_processing/aec3/vector_math.h +++ b/webrtc/modules/audio_processing/aec3/vector_math.h @@ -12,6 +12,9 @@ #define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_ #include "webrtc/typedefs.h" +#if defined(WEBRTC_HAS_NEON) +#include +#endif #if defined(WEBRTC_ARCH_X86_FAMILY) #include #endif @@ -53,6 +56,51 @@ class VectorMath { } } break; #endif +#if defined(WEBRTC_HAS_NEON) + case Aec3Optimization::kNeon: { + const int x_size = static_cast(x.size()); + const int vector_limit = x_size >> 2; + + int j = 0; + for (; j < vector_limit * 4; j += 4) { + float32x4_t g = vld1q_f32(&x[j]); +#if !defined(WEBRTC_ARCH_ARM64) + float32x4_t y = vrsqrteq_f32(g); + + // Code to handle sqrt(0). + // If the input to sqrtf() is zero, a zero will be returned. + // If the input to vrsqrteq_f32() is zero, positive infinity is + // returned. + const uint32x4_t vec_p_inf = vdupq_n_u32(0x7F800000); + // check for divide by zero + const uint32x4_t div_by_zero = + vceqq_u32(vec_p_inf, vreinterpretq_u32_f32(y)); + // zero out the positive infinity results + y = vreinterpretq_f32_u32( + vandq_u32(vmvnq_u32(div_by_zero), vreinterpretq_u32_f32(y))); + // From the ARM documentation: + // The Newton-Raphson iteration: + // y[n+1] = y[n] * (3 - d * (y[n] * y[n])) / 2 + // converges to (1/√d) if y0 is the result of VRSQRTE applied to d. + // + // Note: The precision did not improve after 2 iterations. 
+ for (int i = 0; i < 2; i++) { + y = vmulq_f32(vrsqrtsq_f32(vmulq_f32(y, y), g), y); + } + // sqrt(g) = g * 1/sqrt(g) + g = vmulq_f32(g, y); +#else + g = vsqrtq_f32(g); +#endif + vst1q_f32(&x[j], g); + } + + for (; j < x_size; ++j) { + x[j] = sqrtf(x[j]); + } + } +#endif + break; default: std::for_each(x.begin(), x.end(), [](float& a) { a = sqrtf(a); }); } @@ -83,6 +131,24 @@ class VectorMath { } } break; #endif +#if defined(WEBRTC_HAS_NEON) + case Aec3Optimization::kNeon: { + const int x_size = static_cast(x.size()); + const int vector_limit = x_size >> 2; + + int j = 0; + for (; j < vector_limit * 4; j += 4) { + const float32x4_t x_j = vld1q_f32(&x[j]); + const float32x4_t y_j = vld1q_f32(&y[j]); + const float32x4_t z_j = vmulq_f32(x_j, y_j); + vst1q_f32(&z[j], z_j); + } + + for (; j < x_size; ++j) { + z[j] = x[j] * y[j]; + } + } break; +#endif default: std::transform(x.begin(), x.end(), y.begin(), z.begin(), std::multiplies()); @@ -111,6 +177,24 @@ class VectorMath { } } break; #endif +#if defined(WEBRTC_HAS_NEON) + case Aec3Optimization::kNeon: { + const int x_size = static_cast(x.size()); + const int vector_limit = x_size >> 2; + + int j = 0; + for (; j < vector_limit * 4; j += 4) { + const float32x4_t x_j = vld1q_f32(&x[j]); + float32x4_t z_j = vld1q_f32(&z[j]); + z_j = vaddq_f32(z_j, x_j); + vst1q_f32(&z[j], z_j); + } + + for (; j < x_size; ++j) { + z[j] += x[j]; + } + } break; +#endif default: std::transform(x.begin(), x.end(), z.begin(), z.begin(), std::plus()); diff --git a/webrtc/modules/audio_processing/aec3/vector_math_unittest.cc b/webrtc/modules/audio_processing/aec3/vector_math_unittest.cc index b40cf8d2a5..924ce31e88 100644 --- a/webrtc/modules/audio_processing/aec3/vector_math_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/vector_math_unittest.cc @@ -18,6 +18,65 @@ namespace webrtc { +#if defined(WEBRTC_HAS_NEON) + +TEST(VectorMath, Sqrt) { + std::array x; + std::array z; + std::array z_neon; + + for (size_t k = 0; k < x.size(); ++k) { + 
x[k] = (2.f / 3.f) * k; + } + + std::copy(x.begin(), x.end(), z.begin()); + aec3::VectorMath(Aec3Optimization::kNone).Sqrt(z); + std::copy(x.begin(), x.end(), z_neon.begin()); + aec3::VectorMath(Aec3Optimization::kNeon).Sqrt(z_neon); + for (size_t k = 0; k < z.size(); ++k) { + EXPECT_NEAR(z[k], z_neon[k], 0.0001f); + EXPECT_NEAR(sqrtf(x[k]), z_neon[k], 0.0001f); + } +} + +TEST(VectorMath, Multiply) { + std::array x; + std::array y; + std::array z; + std::array z_neon; + + for (size_t k = 0; k < x.size(); ++k) { + x[k] = k; + y[k] = (2.f / 3.f) * k; + } + + aec3::VectorMath(Aec3Optimization::kNone).Multiply(x, y, z); + aec3::VectorMath(Aec3Optimization::kNeon).Multiply(x, y, z_neon); + for (size_t k = 0; k < z.size(); ++k) { + EXPECT_FLOAT_EQ(z[k], z_neon[k]); + EXPECT_FLOAT_EQ(x[k] * y[k], z_neon[k]); + } +} + +TEST(VectorMath, Accumulate) { + std::array x; + std::array z; + std::array z_neon; + + for (size_t k = 0; k < x.size(); ++k) { + x[k] = k; + z[k] = z_neon[k] = 2.f * k; + } + + aec3::VectorMath(Aec3Optimization::kNone).Accumulate(x, z); + aec3::VectorMath(Aec3Optimization::kNeon).Accumulate(x, z_neon); + for (size_t k = 0; k < z.size(); ++k) { + EXPECT_FLOAT_EQ(z[k], z_neon[k]); + EXPECT_FLOAT_EQ(x[k] + 2.f * x[k], z_neon[k]); + } +} +#endif + #if defined(WEBRTC_ARCH_X86_FAMILY) TEST(VectorMath, Sqrt) {