diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
index a4285bab5a..fafea4294c 100644
--- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn
+++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
@@ -78,6 +78,35 @@ rtc_library("rnn_vad_lp_residual") {
   ]
 }
 
+rtc_source_set("vector_math") {
+  sources = [ "vector_math.h" ]
+  deps = [
+    "..:cpu_features",
+    "../../../../api:array_view",
+    "../../../../rtc_base:checks",
+    "../../../../rtc_base/system:arch",
+  ]
+}
+
+if (current_cpu == "x86" || current_cpu == "x64") {
+  rtc_library("vector_math_avx2") {
+    sources = [ "vector_math_avx2.cc" ]
+    if (is_win) {
+      cflags = [ "/arch:AVX2" ]
+    } else {
+      cflags = [
+        "-mavx2",
+        "-mfma",
+      ]
+    }
+    deps = [
+      ":vector_math",
+      "../../../../api:array_view",
+      "../../../../rtc_base:checks",
+    ]
+  }
+}
+
 rtc_library("rnn_vad_pitch") {
   sources = [
     "pitch_search.cc",
@@ -88,6 +117,7 @@ rtc_library("rnn_vad_pitch") {
   deps = [
     ":rnn_vad_auto_correlation",
     ":rnn_vad_common",
+    ":vector_math",
     "..:cpu_features",
     "../../../../api:array_view",
     "../../../../rtc_base:checks",
@@ -95,6 +125,9 @@ rtc_library("rnn_vad_pitch") {
     "../../../../rtc_base:safe_compare",
     "../../../../rtc_base:safe_conversions",
   ]
+  if (current_cpu == "x86" || current_cpu == "x64") {
+    deps += [ ":vector_math_avx2" ]
+  }
 }
 
 rtc_source_set("rnn_vad_ring_buffer") {
@@ -191,6 +224,7 @@ if (rtc_include_tests) {
       "spectral_features_internal_unittest.cc",
       "spectral_features_unittest.cc",
       "symmetric_matrix_buffer_unittest.cc",
+      "vector_math_unittest.cc",
     ]
     deps = [
       ":rnn_vad",
@@ -203,6 +237,7 @@ if (rtc_include_tests) {
       ":rnn_vad_spectral_features",
       ":rnn_vad_symmetric_matrix_buffer",
       ":test_utils",
+      ":vector_math",
       "..:cpu_features",
       "../..:audioproc_test_utils",
       "../../../../api:array_view",
@@ -216,6 +251,9 @@ if (rtc_include_tests) {
       "../../utility:pffft_wrapper",
       "//third_party/rnnoise:rnn_vad",
     ]
+    if (current_cpu == "x86" || current_cpu == "x64") {
+      deps += [ ":vector_math_avx2" ]
+    }
     absl_deps = [ "//third_party/abseil-cpp/absl/memory" ]
     data = unittest_resources
     if (is_ios) {
diff --git a/modules/audio_processing/agc2/rnn_vad/vector_math.h b/modules/audio_processing/agc2/rnn_vad/vector_math.h
new file mode 100644
index 0000000000..a989682bf3
--- /dev/null
+++ b/modules/audio_processing/agc2/rnn_vad/vector_math.h
@@ -0,0 +1,58 @@
+/*
+ *  Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
+#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
+
+#include <numeric>
+
+#include "api/array_view.h"
+#include "modules/audio_processing/agc2/cpu_features.h"
+#include "rtc_base/checks.h"
+#include "rtc_base/system/arch.h"
+
+namespace webrtc {
+namespace rnn_vad {
+
+// Provides optimizations for mathematical operations having vectors as
+// operand(s).
+class VectorMath {
+ public:
+  explicit VectorMath(AvailableCpuFeatures cpu_features)
+      : cpu_features_(cpu_features) {}
+
+  // Computes the dot product between two equally sized vectors. On x86, uses
+  // the AVX2 implementation when `cpu_features_.avx2` is set; otherwise falls
+  // back to `std::inner_product()`.
+  float DotProduct(rtc::ArrayView<const float> x,
+                   rtc::ArrayView<const float> y) const {
+#if defined(WEBRTC_ARCH_X86_FAMILY)
+    if (cpu_features_.avx2) {
+      return DotProductAvx2(x, y);
+    }
+    // TODO(bugs.webrtc.org/10480): Add SSE2 alternative implementation.
+#endif
+    // TODO(bugs.webrtc.org/10480): Add NEON alternative implementation.
+    RTC_DCHECK_EQ(x.size(), y.size());
+    return std::inner_product(x.begin(), x.end(), y.begin(), 0.f);
+  }
+
+ private:
+  // AVX2 implementation of `DotProduct()`; defined in vector_math_avx2.cc.
+  float DotProductAvx2(rtc::ArrayView<const float> x,
+                       rtc::ArrayView<const float> y) const;
+
+  const AvailableCpuFeatures cpu_features_;
+};
+
+}  // namespace rnn_vad
+}  // namespace webrtc
+
+#endif  // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
diff --git a/modules/audio_processing/agc2/rnn_vad/vector_math_avx2.cc b/modules/audio_processing/agc2/rnn_vad/vector_math_avx2.cc
new file mode 100644
index 0000000000..3b2c4ade03
--- /dev/null
+++ b/modules/audio_processing/agc2/rnn_vad/vector_math_avx2.cc
@@ -0,0 +1,56 @@
+/*
+ *  Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
+
+#include <immintrin.h>
+
+#include "api/array_view.h"
+#include "rtc_base/checks.h"
+
+namespace webrtc {
+namespace rnn_vad {
+
+// Computes the dot product of `x` and `y` using 256-bit fused multiply-add on
+// blocks of 8 floats; the remaining `x.size() % 8` items are added with scalar
+// code.
+float VectorMath::DotProductAvx2(rtc::ArrayView<const float> x,
+                                 rtc::ArrayView<const float> y) const {
+  RTC_DCHECK(cpu_features_.avx2);
+  RTC_DCHECK_EQ(x.size(), y.size());
+  __m256 accumulator = _mm256_setzero_ps();
+  constexpr int kBlockSizeLog2 = 3;
+  constexpr int kBlockSize = 1 << kBlockSizeLog2;
+  const int incomplete_block_index = (x.size() >> kBlockSizeLog2)
+                                     << kBlockSizeLog2;
+  for (int i = 0; i < incomplete_block_index; i += kBlockSize) {
+    RTC_DCHECK_LE(i + kBlockSize, x.size());
+    const __m256 x_i = _mm256_loadu_ps(&x[i]);
+    const __m256 y_i = _mm256_loadu_ps(&y[i]);
+    accumulator = _mm256_fmadd_ps(x_i, y_i, accumulator);
+  }
+  // Reduce `accumulator` by addition.
+  __m128 high = _mm256_extractf128_ps(accumulator, 1);  // Upper 4 lanes.
+  __m128 low = _mm256_extractf128_ps(accumulator, 0);   // Lower 4 lanes.
+  low = _mm_add_ps(high, low);         // 8 -> 4 partial sums.
+  high = _mm_movehl_ps(high, low);     // Move partial sums 2-3 to lanes 0-1.
+  low = _mm_add_ps(high, low);         // 4 -> 2 partial sums (lanes 0-1).
+  high = _mm_shuffle_ps(low, low, 1);  // Move lane 1 to lane 0.
+  low = _mm_add_ss(high, low);         // 2 -> 1; lane 0 is the full sum.
+  float dot_product = _mm_cvtss_f32(low);
+  // Add the result for the last block if incomplete.
+  for (int i = incomplete_block_index; static_cast<size_t>(i) < x.size(); ++i) {
+    dot_product += x[i] * y[i];
+  }
+  return dot_product;
+}
+
+}  // namespace rnn_vad
+}  // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/vector_math_unittest.cc b/modules/audio_processing/agc2/rnn_vad/vector_math_unittest.cc
new file mode 100644
index 0000000000..19a8af0cab
--- /dev/null
+++ b/modules/audio_processing/agc2/rnn_vad/vector_math_unittest.cc
@@ -0,0 +1,69 @@
+/*
+ *  Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
+
+#include <vector>
+
+#include "modules/audio_processing/agc2/cpu_features.h"
+#include "test/gtest.h"
+
+namespace webrtc {
+namespace rnn_vad {
+namespace {
+
+constexpr int kSizeOfX = 19;
+constexpr float kX[kSizeOfX] = {
+    0.31593041f, 0.9350786f,   -0.25252445f, -0.86956251f, -0.9673632f,
+    0.54571901f, -0.72504495f, -0.79509912f, -0.25525012f, -0.73340473f,
+    0.15747377f, -0.04370565f, 0.76135145f,  -0.57239645f, 0.68616848f,
+    0.3740298f,  0.34710799f,  -0.92207423f, 0.10738454f};
+constexpr int kSizeOfXSubSpan = 16;
+static_assert(kSizeOfXSubSpan < kSizeOfX, "");
+constexpr float kEnergyOfX = 7.315563958160327f;
+constexpr float kEnergyOfXSubspan = 6.333327669592963f;
+
+class VectorMathParametrization
+    : public ::testing::TestWithParam<AvailableCpuFeatures> {};
+
+// Checks the dot product of `kX` with itself on the full vector (19 items, not
+// a multiple of 8) and on its first 16 items (a multiple of 8).
+TEST_P(VectorMathParametrization, TestDotProduct) {
+  VectorMath vector_math(/*cpu_features=*/GetParam());
+  EXPECT_FLOAT_EQ(vector_math.DotProduct(kX, kX), kEnergyOfX);
+  EXPECT_FLOAT_EQ(
+      vector_math.DotProduct({kX, kSizeOfXSubSpan}, {kX, kSizeOfXSubSpan}),
+      kEnergyOfXSubspan);
+}
+
+// Finds the relevant CPU features combinations to test.
+std::vector<AvailableCpuFeatures> GetCpuFeaturesToTest() {
+  std::vector<AvailableCpuFeatures> v;
+  v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/false});
+  AvailableCpuFeatures available = GetAvailableCpuFeatures();
+  if (available.avx2) {
+    AvailableCpuFeatures features(
+        {/*sse2=*/false, /*avx2=*/true, /*neon=*/false});
+    v.push_back(features);
+  }
+  return v;
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    RnnVadTest,
+    VectorMathParametrization,
+    ::testing::ValuesIn(GetCpuFeaturesToTest()),
+    [](const ::testing::TestParamInfo<AvailableCpuFeatures>& info) {
+      return info.param.ToString();
+    });
+
+}  // namespace
+}  // namespace rnn_vad
+}  // namespace webrtc