From b6e840c0368f94366ad7f2c554d84e1d14f36ad4 Mon Sep 17 00:00:00 2001 From: Alessio Bazzica Date: Thu, 26 Nov 2020 14:29:46 +0100 Subject: [PATCH] RNN VAD: SSE2 optimization for `VectorMath::DotProduct` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug: webrtc:10480 Change-Id: I9f40352308bbfd5ea72a2607e7d1184cb6b85333 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/194328 Commit-Queue: Alessio Bazzica Reviewed-by: Per Åhgren Cr-Commit-Position: refs/heads/master@{#32745} --- .../agc2/rnn_vad/vector_math.h | 36 +++++++++++++++++-- .../agc2/rnn_vad/vector_math_unittest.cc | 7 ++-- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/modules/audio_processing/agc2/rnn_vad/vector_math.h b/modules/audio_processing/agc2/rnn_vad/vector_math.h index a989682bf3..51bbbfbd7e 100644 --- a/modules/audio_processing/agc2/rnn_vad/vector_math.h +++ b/modules/audio_processing/agc2/rnn_vad/vector_math.h @@ -11,6 +11,13 @@ #ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_ #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_ +// Defines WEBRTC_ARCH_X86_FAMILY, used below. +#include "rtc_base/system/arch.h" + +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include +#endif + #include #include "api/array_view.h" @@ -31,14 +38,39 @@ class VectorMath { // Computes the dot product between two equally sized vectors. 
float DotProduct(rtc::ArrayView x, rtc::ArrayView y) const { + RTC_DCHECK_EQ(x.size(), y.size()); #if defined(WEBRTC_ARCH_X86_FAMILY) if (cpu_features_.avx2) { return DotProductAvx2(x, y); + } else if (cpu_features_.sse2) { + __m128 accumulator = _mm_setzero_ps(); + constexpr int kBlockSizeLog2 = 2; + constexpr int kBlockSize = 1 << kBlockSizeLog2; + const int incomplete_block_index = (x.size() >> kBlockSizeLog2) + << kBlockSizeLog2; + for (int i = 0; i < incomplete_block_index; i += kBlockSize) { + RTC_DCHECK_LE(i + kBlockSize, x.size()); + const __m128 x_i = _mm_loadu_ps(&x[i]); + const __m128 y_i = _mm_loadu_ps(&y[i]); + // Multiply-add. + const __m128 z_j = _mm_mul_ps(x_i, y_i); + accumulator = _mm_add_ps(accumulator, z_j); + } + // Reduce `accumulator` by addition. + __m128 high = _mm_movehl_ps(accumulator, accumulator); + accumulator = _mm_add_ps(accumulator, high); + high = _mm_shuffle_ps(accumulator, accumulator, 1); + accumulator = _mm_add_ps(accumulator, high); + float dot_product = _mm_cvtss_f32(accumulator); + // Add the result for the last block if incomplete. + for (int i = incomplete_block_index; static_cast(i) < x.size(); + ++i) { + dot_product += x[i] * y[i]; + } + return dot_product; } - // TODO(bugs.webrtc.org/10480): Add SSE2 alternative implementation. #endif // TODO(bugs.webrtc.org/10480): Add NEON alternative implementation. 
- RTC_DCHECK_EQ(x.size(), y.size()); return std::inner_product(x.begin(), x.end(), y.begin(), 0.f); } diff --git a/modules/audio_processing/agc2/rnn_vad/vector_math_unittest.cc b/modules/audio_processing/agc2/rnn_vad/vector_math_unittest.cc index 19a8af0cab..9a2d5bc116 100644 --- a/modules/audio_processing/agc2/rnn_vad/vector_math_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/vector_math_unittest.cc @@ -47,9 +47,10 @@ std::vector GetCpuFeaturesToTest() { v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/false}); AvailableCpuFeatures available = GetAvailableCpuFeatures(); if (available.avx2) { - AvailableCpuFeatures features( - {/*sse2=*/false, /*avx2=*/true, /*neon=*/false}); - v.push_back(features); + v.push_back({/*sse2=*/false, /*avx2=*/true, /*neon=*/false}); + } + if (available.sse2) { + v.push_back({/*sse2=*/true, /*avx2=*/false, /*neon=*/false}); } return v; }